Source code for tensorrt_llm.runtime.generation
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
<span class="kn">import</span><span class="w"> </span><span class="nn">copy</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">math</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">os</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">platform</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">collections</span><span class="w"> </span><span class="kn">import</span> <span class="n">Counter</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">dataclasses</span><span class="w"> </span><span class="kn">import</span> <span class="n">dataclass</span><span class="p">,</span> <span class="n">field</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">functools</span><span class="w"> </span><span class="kn">import</span> <span class="n">reduce</span><span class="p">,</span> <span class="n">wraps</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pathlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">Path</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">typing</span><span class="w"> </span><span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Iterable</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">,</span> <span class="n">Set</span><span class="p">,</span> <span class="n">Union</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
<span class="c1"># isort: off</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">tensorrt</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">trt</span>
<span class="c1"># isort: on</span>
<span class="k">try</span><span class="p">:</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">cuda.bindings</span><span class="w"> </span><span class="kn">import</span> <span class="n">runtime</span> <span class="k">as</span> <span class="n">cudart</span>
<span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">cuda</span><span class="w"> </span><span class="kn">import</span> <span class="n">cudart</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">tensorrt_llm.runtime.memory_pools.memory_pools_allocator</span><span class="w"> </span><span class="kn">import</span> \
<span class="n">MemoryPoolsAllocator</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">tensorrt_llm.runtime.memory_pools.pools_kv_cache_manager</span><span class="w"> </span><span class="kn">import</span> \
<span class="n">PoolsKVCacheManager</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">tensorrt_llm.runtime.redrafter_utils</span><span class="w"> </span><span class="kn">import</span> <span class="o">*</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.._utils</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span><span class="n">binding_layer_type_to_str</span><span class="p">,</span> <span class="n">binding_to_str_dtype</span><span class="p">,</span>
<span class="n">pad_vocab_size</span><span class="p">,</span> <span class="n">str_dtype_to_torch</span><span class="p">,</span> <span class="n">torch_to_numpy</span><span class="p">,</span>
<span class="n">trt_dtype_to_torch</span><span class="p">)</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">..bindings</span><span class="w"> </span><span class="kn">import</span> <span class="n">ipc_nvls_allocate</span><span class="p">,</span> <span class="n">ipc_nvls_free</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">..layers</span><span class="w"> </span><span class="kn">import</span> <span class="n">LanguageAdapterConfig</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">..llmapi.kv_cache_type</span><span class="w"> </span><span class="kn">import</span> <span class="n">KVCacheType</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">..logger</span><span class="w"> </span><span class="kn">import</span> <span class="n">logger</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">..lora_manager</span><span class="w"> </span><span class="kn">import</span> <span class="n">LoraManager</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">..mapping</span><span class="w"> </span><span class="kn">import</span> <span class="n">Mapping</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">..plugin.plugin</span><span class="w"> </span><span class="kn">import</span> <span class="n">CustomAllReduceHelper</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">..quantization</span><span class="w"> </span><span class="kn">import</span> <span class="n">QuantMode</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.kv_cache_manager</span><span class="w"> </span><span class="kn">import</span> <span class="n">GenerationSequence</span><span class="p">,</span> <span class="n">KVCacheUpdater</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.session</span><span class="w"> </span><span class="kn">import</span> <span class="n">_scoped_stream</span>
<span class="c1"># When variable is set, this will disable torch.cuda.set_device(...) calls</span>
<span class="c1"># Useful in situations where device is already assigned by another library, i.e., megatron.</span>
<span class="n">DISABLE_TORCH_DEVICE_SET</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;DISABLE_TORCH_DEVICE_SET&quot;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<div class="viewcode-block" id="decode_words_list">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.decode_words_list">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">decode_words_list</span><span class="p">(</span><span class="n">word_dict</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]],</span>
<span class="n">tokenizer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">add_special_tokens</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> format of word_dict</span>
<span class="sd"> len(word_dict) should be same to batch_size</span>
<span class="sd"> word_dict[i] means the words for batch i</span>
<span class="sd"> len(word_dict[i]) &gt;= 1, which means it must contain at least 1 string</span>
<span class="sd"> For example, word_dict[2] = [&quot; I am happy&quot;, &quot; I am sad&quot;].</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="n">tokenizer</span> <span class="o">!=</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;need to set tokenizer&quot;</span>
<span class="n">decoded_words_batch</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">word_dict_item</span> <span class="ow">in</span> <span class="n">word_dict</span><span class="p">:</span>
<span class="n">decoded_words_request</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">item</span> <span class="ow">in</span> <span class="n">word_dict_item</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">item</span><span class="p">,</span> <span class="nb">bytes</span><span class="p">):</span>
<span class="n">item</span> <span class="o">=</span> <span class="p">[</span><span class="n">item</span><span class="o">.</span><span class="n">decode</span><span class="p">()]</span>
<span class="n">ids</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">item</span><span class="p">,</span> <span class="n">add_special_tokens</span><span class="o">=</span><span class="n">add_special_tokens</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">decoded_words_request</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">ids</span><span class="p">)</span>
<span class="n">decoded_words_batch</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">decoded_words_request</span><span class="p">)</span>
<span class="k">return</span> <span class="n">decoded_words_batch</span></div>
<span class="k">def</span><span class="w"> </span><span class="nf">to_word_list_format</span><span class="p">(</span><span class="n">word_dict</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]]):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> format of word_dict</span>
<span class="sd"> len(word_dict) should be same to batch_size</span>
<span class="sd"> word_dict[i] means the words for batch i</span>
<span class="sd"> len(word_dict[i]) &gt;= 1, which means it must contain at least 1 word</span>
<span class="sd"> For example, word_dict[2] = [[1, 267], [534]] has two words.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">flat_ids</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">offsets</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">word_dict_item</span> <span class="ow">in</span> <span class="n">word_dict</span><span class="p">:</span>
<span class="n">items_flat_ids</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">items_offsets</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">ids</span> <span class="ow">in</span> <span class="n">word_dict_item</span><span class="p">:</span>
<span class="n">items_flat_ids</span> <span class="o">+=</span> <span class="n">ids</span>
<span class="n">items_offsets</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="p">))</span>
<span class="n">flat_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">items_flat_ids</span><span class="p">))</span>
<span class="n">offsets</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">items_offsets</span><span class="p">)))</span>
<span class="n">pad_to</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">max</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="p">)</span> <span class="k">for</span> <span class="n">ids</span> <span class="ow">in</span> <span class="n">flat_ids</span><span class="p">))</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="n">offs</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">flat_ids</span><span class="p">,</span> <span class="n">offsets</span><span class="p">)):</span>
<span class="n">flat_ids</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">pad_to</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="p">)),</span> <span class="n">constant_values</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">offsets</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="n">offs</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">pad_to</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">offs</span><span class="p">)),</span> <span class="n">constant_values</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">flat_ids</span><span class="p">,</span> <span class="n">offsets</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;int32&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_prepare_input_ids</span><span class="p">(</span><span class="n">tensors</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span>
<span class="n">tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">t</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">]</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span><span class="n">tensors</span><span class="p">)</span>
<span class="n">row_lengths</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">]</span>
<span class="n">row_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">row_lengths</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">data</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">row_lengths</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">CUASSERT</span><span class="p">(</span><span class="n">cuda_ret</span><span class="p">):</span>
<span class="n">err</span> <span class="o">=</span> <span class="n">cuda_ret</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">if</span> <span class="n">err</span> <span class="o">!=</span> <span class="n">cudart</span><span class="o">.</span><span class="n">cudaError_t</span><span class="o">.</span><span class="n">cudaSuccess</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;CUDA ERROR: </span><span class="si">{</span><span class="n">err</span><span class="si">}</span><span class="s2">, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">cuda_ret</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">return</span> <span class="n">cuda_ret</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span>
<span class="k">return</span> <span class="kc">None</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_update_cuda_graph_instance</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="n">graph</span><span class="p">):</span>
<span class="n">err</span> <span class="o">=</span> <span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphExecUpdate</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="n">graph</span><span class="p">)</span>
<span class="k">if</span> <span class="n">err</span> <span class="o">!=</span> <span class="n">cudart</span><span class="o">.</span><span class="n">cudaError_t</span><span class="o">.</span><span class="n">cudaSuccess</span><span class="p">:</span>
<span class="c1"># When updating cuda graph failed, destroy and instantiate one.</span>
<span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphExecDestroy</span><span class="p">(</span><span class="n">instance</span><span class="p">))</span>
<span class="n">instance</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphInstantiate</span><span class="p">(</span><span class="n">graph</span><span class="p">,</span> <span class="mi">0</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">return</span> <span class="n">instance</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_prepare_attention_mask</span><span class="p">(</span><span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">pad_id</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="n">is_pad_id_in_inputs</span> <span class="o">=</span> <span class="p">(</span><span class="n">pad_id</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="ow">and</span> <span class="p">(</span><span class="n">pad_id</span> <span class="ow">in</span> <span class="n">input_ids</span><span class="p">)</span>
<span class="k">if</span> <span class="n">input_ids</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">is_pad_id_in_inputs</span><span class="p">:</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">ne</span><span class="p">(</span><span class="n">pad_id</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="c1"># for enc-dec models, pad_id could be the start token and should be always counted</span>
<span class="c1"># as valid token rather than padded token, so we force its mask to be 1.</span>
<span class="c1"># This doesn&#39;t impact the existing behavior</span>
<span class="n">mask</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="k">return</span> <span class="n">mask</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_tile_beam_width</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">new_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">num_beams</span>
<span class="n">tile_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">new_shape</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">tile_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="n">tile_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">tile_size</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">new_shape</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
<span class="k">return</span> <span class="n">new_tensor</span>
<span class="k">class</span><span class="w"> </span><span class="nc">_Profiler</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">IProfiler</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">results</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">def</span><span class="w"> </span><span class="nf">report_layer_time</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer_name</span><span class="p">,</span> <span class="n">ms</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">results</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">layer_name</span><span class="p">,</span> <span class="n">ms</span><span class="p">))</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_contiguous_tile_beam_width</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">num_beams</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">new_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*=</span> <span class="n">num_beams</span>
<span class="n">numel</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">num_beams</span> <span class="o">*</span> <span class="n">numel</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">tensor</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="c1"># Take the first &#39;size&#39; values to tile and skip the others.</span>
<span class="n">vals</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)[:</span><span class="n">size</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_beams</span><span class="p">):</span>
<span class="n">new_tensor</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">size</span><span class="p">:(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">size</span><span class="p">]</span> <span class="o">=</span> <span class="n">vals</span>
<span class="k">return</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
<span class="k">class</span><span class="w"> </span><span class="nc">_Runtime</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="n">runtime_rank</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">runtime</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">Runtime</span>
<span class="n">engine</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ICudaEngine</span>
<span class="n">ctx_context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span>
<span class="n">context_0</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span>
<span class="n">context_1</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span>
<span class="n">profiler</span><span class="p">:</span> <span class="n">_Profiler</span>
<span class="n">engine_inspector</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">EngineInspector</span>
<span class="n">cuda_graph_instances</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphExec_t</span><span class="p">]</span>
<span class="n">input_tensor_names</span><span class="p">:</span> <span class="n">Set</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span>
<span class="n">output_tensor_names</span><span class="p">:</span> <span class="n">Set</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">,</span> <span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">address</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">__prepare</span><span class="p">(</span><span class="n">mapping</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_serialize_engine</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">trt</span><span class="o">.</span><span class="n">IHostMemory</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">serialize</span><span class="p">()</span>
<span class="k">def</span><span class="w"> </span><span class="nf">__create_and_setup_context</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">address</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">profile_idx</span><span class="p">,</span>
<span class="n">stream</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">:</span>
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">create_execution_context_without_device_memory</span><span class="p">()</span>
<span class="k">assert</span> <span class="n">context</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;Failed to create an execution context with the provided device memory!&quot;</span>
<span class="n">context</span><span class="o">.</span><span class="n">set_device_memory</span><span class="p">(</span><span class="n">address</span><span class="p">,</span> <span class="n">size</span><span class="p">)</span>
<span class="n">context</span><span class="o">.</span><span class="n">set_optimization_profile_async</span><span class="p">(</span><span class="n">profile_idx</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="c1"># If nvtx verbosity is DETAILED, change it to LAYER_NAMES_ONLY for inference performance</span>
<span class="k">if</span> <span class="n">context</span><span class="o">.</span><span class="n">nvtx_verbosity</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">ProfilingVerbosity</span><span class="o">.</span><span class="n">DETAILED</span><span class="p">:</span>
<span class="n">context</span><span class="o">.</span><span class="n">nvtx_verbosity</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ProfilingVerbosity</span><span class="o">.</span><span class="n">LAYER_NAMES_ONLY</span>
<span class="k">return</span> <span class="n">context</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_set_profiler</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">profiler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">context_0</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">context_1</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">profiler</span> <span class="o">=</span> <span class="n">_Profiler</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_0</span><span class="o">.</span><span class="n">profiler</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">profiler</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_0</span><span class="o">.</span><span class="n">enqueue_emits_profile</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_1</span><span class="o">.</span><span class="n">profiler</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">profiler</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_1</span><span class="o">.</span><span class="n">enqueue_emits_profile</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_optimization_profiles</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span><span class="o">.</span><span class="n">profiler</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">profiler</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span><span class="o">.</span><span class="n">enqueue_emits_profile</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">def</span><span class="w"> </span><span class="nf">__prepare</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime_rank</span> <span class="o">=</span> <span class="n">mapping</span><span class="o">.</span><span class="n">rank</span>
<span class="n">local_rank</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime_rank</span> <span class="o">%</span> <span class="n">mapping</span><span class="o">.</span><span class="n">gpus_per_node</span>
<span class="k">if</span> <span class="n">DISABLE_TORCH_DEVICE_SET</span><span class="p">:</span>
<span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaSetDevice</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_device</span><span class="p">()))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="n">local_rank</span><span class="p">)</span>
<span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaSetDevice</span><span class="p">(</span><span class="n">local_rank</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Runtime</span><span class="p">(</span><span class="n">logger</span><span class="o">.</span><span class="n">trt_logger</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">engine</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">deserialize_cuda_engine</span><span class="p">(</span><span class="n">engine_buffer</span><span class="p">)</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">input_tensor_names</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_tensor_names</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">):</span>
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_mode</span><span class="p">(</span><span class="n">name</span><span class="p">)</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">TensorIOMode</span><span class="o">.</span><span class="n">OUTPUT</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_tensor_names</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">input_tensor_names</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">profiler</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">engine_inspector</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">create_engine_inspector</span><span class="p">()</span>
<span class="c1"># cuda graph ping-pong instances</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_instances</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">)]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">streamable_weights_size</span><span class="p">:</span>
<span class="c1"># engine does not have weight streaming enabled</span>
<span class="bp">self</span><span class="o">.</span><span class="n">__prepare_execution_contexts</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">weight_streaming_budget_v2</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># avoid OOM when print engine info</span>
<span class="k">if</span> <span class="n">logger</span><span class="o">.</span><span class="n">level</span> <span class="o">==</span> <span class="s2">&quot;verbose&quot;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">__print_engine_info</span><span class="p">()</span>
<span class="k">def</span><span class="w"> </span><span class="nf">__prepare_execution_contexts</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_0</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_1</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span> <span class="o">=</span> <span class="kc">None</span>
<span class="c1"># The device_memory_size_v2 stores the memory required by the largest profile.</span>
<span class="c1"># When weight streaming is enable, it must be queried after the weight streaming budget set.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">device_memory_size_v2</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">device_memory_size_v2</span>
<span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaFree</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">))</span>
<span class="n">address</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaMalloc</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">address</span> <span class="o">=</span> <span class="n">address</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">device_memory_size_v2</span>
<span class="n">address</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaMalloc</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">address</span> <span class="o">=</span> <span class="n">address</span>
<span class="k">with</span> <span class="n">_scoped_stream</span><span class="p">()</span> <span class="k">as</span> <span class="n">stream</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_optimization_profiles</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># At step = 0, context_1 is active</span>
<span class="c1"># At step = 1, context_0 is active</span>
<span class="c1"># At step = 2, context_1 is active</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_0</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">context_1</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_optimization_profiles</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
<span class="c1"># At step = 0, ctx_context is active</span>
<span class="c1"># At step = 1, context_0 is active</span>
<span class="c1"># At step = 2, context_1 is active</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_0</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Number of optimization profiles: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_optimization_profiles</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span>
<span class="s2">&quot;Python runtime only support 1 or 2 optimization profiles, &quot;</span>
<span class="s2">&quot;set --multiple_profiles=disable when calling trtllm-build &quot;</span>
<span class="s2">&quot;to disable the feature.&quot;</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">__print_engine_info</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">engine</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span>
<span class="n">context</span> <span class="o">=</span> <span class="n">engine</span><span class="o">.</span><span class="n">create_execution_context</span><span class="p">(</span>
<span class="n">trt</span><span class="o">.</span><span class="n">ExecutionContextAllocationStrategy</span><span class="o">.</span><span class="n">USER_MANAGED</span><span class="p">)</span>
<span class="n">n_op</span> <span class="o">=</span> <span class="n">engine</span><span class="o">.</span><span class="n">num_optimization_profiles</span>
<span class="n">max_name_width</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># Maximum Width of tensor Name</span>
<span class="n">max_shape_width</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># Maximum Width of tensor Shape</span>
<span class="n">tensor_name_list</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">)</span>
<span class="p">]</span>
<span class="c1"># Get information of engine input / output</span>
<span class="n">tid</span> <span class="o">=</span> <span class="p">{}</span> <span class="c1"># Tensor Information Dictionary</span>
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">tensor_name_list</span><span class="p">:</span>
<span class="n">item</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="n">max_name_width</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">max_name_width</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
<span class="n">item</span><span class="p">[</span><span class="s2">&quot;mode&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="s1">&#39;I&#39;</span> <span class="k">if</span> <span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_mode</span><span class="p">(</span>
<span class="n">name</span><span class="p">)</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">TensorIOMode</span><span class="o">.</span><span class="n">INPUT</span> <span class="k">else</span> <span class="s1">&#39;O&#39;</span>
<span class="n">item</span><span class="p">[</span><span class="s2">&quot;location&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="s1">&#39;GPU&#39;</span> <span class="k">if</span> <span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_location</span><span class="p">(</span>
<span class="n">name</span><span class="p">)</span> <span class="k">else</span> <span class="s1">&#39;CPU&#39;</span>
<span class="n">item</span><span class="p">[</span><span class="s2">&quot;data_type&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">))[</span><span class="mi">9</span><span class="p">:]</span>
<span class="n">item</span><span class="p">[</span><span class="s2">&quot;build_shape&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_shape</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
<span class="n">item</span><span class="p">[</span><span class="s2">&quot;profile_list&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="p">[[]</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_op</span><span class="p">)]</span>
<span class="k">if</span> <span class="n">item</span><span class="p">[</span><span class="s2">&quot;mode&quot;</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;I&quot;</span><span class="p">:</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_op</span><span class="p">):</span>
<span class="k">if</span> <span class="n">item</span><span class="p">[</span><span class="s2">&quot;location&quot;</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;GPU&quot;</span><span class="p">:</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_profile_shape</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_profile_value</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">name</span><span class="p">)</span>
<span class="n">item</span><span class="p">[</span><span class="s2">&quot;profile_list&quot;</span><span class="p">][</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
<span class="n">max_shape_width</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">max_shape_width</span><span class="p">,</span>
<span class="o">*</span><span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">s</span><span class="p">))</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">shape</span><span class="p">])</span>
<span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">item</span>
<span class="c1"># Set input shape to get output shape</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_op</span><span class="p">):</span>
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">3</span><span class="p">):</span> <span class="c1"># Min, Opt, Max</span>
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">tid</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="k">if</span> <span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="s2">&quot;mode&quot;</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;I&quot;</span><span class="p">:</span>
<span class="k">if</span> <span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="s2">&quot;location&quot;</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;GPU&quot;</span><span class="p">:</span>
<span class="n">context</span><span class="o">.</span><span class="n">set_input_shape</span><span class="p">(</span>
<span class="n">name</span><span class="p">,</span> <span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="s2">&quot;profile_list&quot;</span><span class="p">][</span><span class="n">k</span><span class="p">][</span><span class="n">j</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">context</span><span class="o">.</span><span class="n">set_tensor_address</span><span class="p">(</span>
<span class="n">name</span><span class="p">,</span>
<span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="s2">&quot;profile_list&quot;</span><span class="p">][</span><span class="n">k</span><span class="p">][</span><span class="n">j</span><span class="p">]</span><span class="o">.</span><span class="n">ctypes</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="s2">&quot;mode&quot;</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;O&quot;</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">context</span><span class="o">.</span><span class="n">all_binding_shapes_specified</span> <span class="ow">and</span> <span class="n">context</span><span class="o">.</span><span class="n">all_shape_inputs_specified</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">get_tensor_shape</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="s2">&quot;profile_list&quot;</span><span class="p">][</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
<span class="c1"># Print information of engine input / output</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="s2">&quot;Information of engine input / output.&quot;</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="s1">&#39;=&#39;</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">24</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="s1">&#39;Name&#39;</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_name_width</span><span class="si">}}</span><span class="s2">|I/O|Location|DataType|</span><span class="si">{</span><span class="s1">&#39;Shape&#39;</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|&quot;</span>
<span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="s1">&#39;-&#39;</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">24</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">tensor_name_list</span><span class="p">:</span>
<span class="n">item</span> <span class="o">=</span> <span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
<span class="n">info</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">name</span><span class="si">:</span><span class="s2">&lt;</span><span class="si">{</span><span class="n">max_name_width</span><span class="si">}}</span><span class="s2">|</span><span class="si">{</span><span class="n">item</span><span class="p">[</span><span class="s1">&#39;mode&#39;</span><span class="p">]</span><span class="si">:</span><span class="s2">^3s</span><span class="si">}</span><span class="s2">|</span><span class="si">{</span><span class="n">item</span><span class="p">[</span><span class="s1">&#39;location&#39;</span><span class="p">]</span><span class="si">:</span><span class="s2">^8s</span><span class="si">}</span><span class="s2">|</span><span class="si">{</span><span class="n">item</span><span class="p">[</span><span class="s1">&#39;data_type&#39;</span><span class="p">]</span><span class="si">:</span><span class="s2">^8s</span><span class="si">}</span><span class="s2">|&quot;</span>
<span class="n">info</span> <span class="o">+=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">item</span><span class="p">[</span><span class="s1">&#39;build_shape&#39;</span><span class="p">]</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|&quot;</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="n">info</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="s1">&#39;=&#39;</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">24</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="c1"># Print information of optimization profile</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="s2">&quot;Information of optimization profile.&quot;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_op</span><span class="p">):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Optimization Profile </span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s2">:&quot;</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="s1">&#39;=&#39;</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">3</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">4</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="s1">&#39;Name&#39;</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_name_width</span><span class="si">}}</span><span class="s2">|</span><span class="si">{</span><span class="s1">&#39;Min&#39;</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|</span><span class="si">{</span><span class="s1">&#39;Opt&#39;</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|</span><span class="si">{</span><span class="s1">&#39;Max&#39;</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|&quot;</span>
<span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="s1">&#39;-&#39;</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">3</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">4</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">tensor_name_list</span><span class="p">:</span>
<span class="n">item</span> <span class="o">=</span> <span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
<span class="n">info</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">name</span><span class="si">:</span><span class="s2">&lt;</span><span class="si">{</span><span class="n">max_name_width</span><span class="si">}}</span><span class="s2">|&quot;</span>
<span class="n">info</span> <span class="o">+=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="nb">str</span><span class="p">(</span><span class="n">item</span><span class="p">[</span><span class="s1">&#39;profile_list&#39;</span><span class="p">][</span><span class="n">k</span><span class="p">][</span><span class="mi">0</span><span class="p">])</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|&quot;</span>
<span class="n">info</span> <span class="o">+=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="nb">str</span><span class="p">(</span><span class="n">item</span><span class="p">[</span><span class="s1">&#39;profile_list&#39;</span><span class="p">][</span><span class="n">k</span><span class="p">][</span><span class="mi">1</span><span class="p">])</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|&quot;</span>
<span class="n">info</span> <span class="o">+=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="nb">str</span><span class="p">(</span><span class="n">item</span><span class="p">[</span><span class="s1">&#39;profile_list&#39;</span><span class="p">][</span><span class="n">k</span><span class="p">][</span><span class="mi">2</span><span class="p">])</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|&quot;</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="n">info</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="s1">&#39;=&#39;</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">3</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">4</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">print_context_info</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">,</span> <span class="n">context_index</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">n_io</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span>
<span class="n">max_name_width</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># Maximum Width of tensor Name</span>
<span class="n">max_shape_width</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># Maximum Width of tensor Shape</span>
<span class="n">tensorInfo</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_io</span><span class="p">):</span>
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="n">b_input</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_mode</span><span class="p">(</span>
<span class="n">name</span><span class="p">)</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">TensorIOMode</span><span class="o">.</span><span class="n">INPUT</span>
<span class="n">shape</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_shape</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
<span class="n">tensorInfo</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="n">name</span><span class="p">,</span> <span class="n">b_input</span><span class="p">,</span> <span class="n">shape</span><span class="p">]</span>
<span class="n">max_name_width</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">max_name_width</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
<span class="n">max_shape_width</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">max_shape_width</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">shape</span><span class="p">))</span>
<span class="c1"># Shape input tensor is not used in TRT-LLM yet</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Information of context input / output.&quot;</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Using Optimization Profile: </span><span class="si">{</span><span class="n">context_index</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="s1">&#39;=&#39;</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">6</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="s1">&#39;Name&#39;</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_name_width</span><span class="si">}}</span><span class="s2">|I/O|</span><span class="si">{</span><span class="s1">&#39;Shape&#39;</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|&quot;</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="s1">&#39;-&#39;</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">6</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_io</span><span class="p">):</span>
<span class="n">name</span><span class="p">,</span> <span class="n">b_input</span><span class="p">,</span> <span class="n">shape</span> <span class="o">=</span> <span class="n">tensorInfo</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">info</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">name</span><span class="si">:</span><span class="s2">&lt;</span><span class="si">{</span><span class="n">max_name_width</span><span class="si">}}</span><span class="s2">|</span><span class="si">{</span><span class="s1">&#39;I&#39;</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="n">b_input</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="s1">&#39;O&#39;</span><span class="si">:</span><span class="s2">^3s</span><span class="si">}</span><span class="s2">|</span><span class="si">{</span><span class="n">shape</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|&quot;</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="n">info</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="s1">&#39;=&#39;</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">6</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_set_shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">,</span>
<span class="n">shape_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">):</span>
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="k">if</span> <span class="n">name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">shape_dict</span><span class="p">:</span>
<span class="c1"># shape and buffer can be set by calling _set_tensors API</span>
<span class="k">continue</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_mode</span><span class="p">(</span><span class="n">name</span><span class="p">)</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">TensorIOMode</span><span class="o">.</span><span class="n">INPUT</span><span class="p">:</span>
<span class="n">ok</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">set_input_shape</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">shape_dict</span><span class="p">[</span><span class="n">name</span><span class="p">])</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;setting input tensor </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> with shape </span><span class="si">{</span><span class="n">shape_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="si">}</span><span class="s2"> and type </span><span class="si">{</span><span class="n">dtype</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">ok</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Couldn&#39;t assign </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> with shape </span><span class="si">{</span><span class="n">shape_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="si">}</span><span class="s2">, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;engine supports [min, opt, max] = </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_profile_shape</span><span class="p">(</span><span class="n">name</span><span class="p">,</span><span class="w"> </span><span class="n">context</span><span class="o">.</span><span class="n">active_optimization_profile</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_set_buffer</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">,</span>
<span class="n">buffer_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">):</span>
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="k">if</span> <span class="n">name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">buffer_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">get_tensor_shape</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="n">buffer_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">shape</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">buffer_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">(</span>
<span class="p">),</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> is not contiguous()&quot;</span>
<span class="n">context</span><span class="o">.</span><span class="n">set_tensor_address</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">buffer_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">())</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_set_tensors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">,</span>
<span class="n">tensors</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="s2">&quot;RuntimeTensor&quot;</span><span class="p">]):</span>
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_tensor_names</span><span class="p">:</span>
<span class="c1"># it&#39;s allowed to call set_tensors multi times with different tensors</span>
<span class="c1"># each time only set some of the engine tensors, so it is valid to skip the ones not in the current given tensors dict</span>
<span class="k">if</span> <span class="n">name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">tensors</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
<span class="k">if</span> <span class="n">context</span><span class="o">.</span><span class="n">get_tensor_address</span><span class="p">(</span><span class="n">name</span><span class="p">)</span> <span class="o">!=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">data</span><span class="p">:</span>
<span class="n">context</span><span class="o">.</span><span class="n">set_tensor_address</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">tensor</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">list</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">get_tensor_shape</span><span class="p">(</span><span class="n">name</span><span class="p">))</span> <span class="o">!=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
<span class="n">context</span><span class="o">.</span><span class="n">set_input_shape</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_tensor_names</span><span class="p">:</span>
<span class="k">if</span> <span class="n">name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">get_tensor_shape</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="n">tensors</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span>
<span class="n">name</span><span class="p">,</span>
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">shape</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">))</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">tensors</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
<span class="c1"># output&#39;s shape is inference by TRT, no need to set the shape here</span>
<span class="n">context</span><span class="o">.</span><span class="n">set_tensor_address</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">t</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_set_weight_streaming</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">gpu_weights_percent</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">streamable_weights_size</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">gpu_weights_percent</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;Engine built without weight streaming. Cannot set gpu_weights_percent to a value other than 1.&quot;</span>
<span class="k">return</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_0</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">context_1</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span> <span class="o">=</span> <span class="kc">None</span>
<span class="nb">min</span> <span class="o">=</span> <span class="mi">0</span>
<span class="nb">max</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">streamable_weights_size</span>
<span class="n">budget</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">gpu_weights_percent</span> <span class="o">*</span> <span class="nb">max</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">weight_streaming_budget_v2</span> <span class="o">=</span> <span class="n">budget</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">weight_streaming_budget_v2</span> <span class="o">==</span> <span class="n">budget</span><span class="p">,</span> <span class="s2">&quot;Failed to set weight streaming budget!&quot;</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Set gpu weights percent to </span><span class="si">{</span><span class="n">gpu_weights_percent</span><span class="si">}</span><span class="s2">, which is </span><span class="si">{</span><span class="n">budget</span><span class="si">}</span><span class="s2"> bytes. Valid range: </span><span class="si">{</span><span class="nb">min</span><span class="si">}</span><span class="s2"> bytes ~ </span><span class="si">{</span><span class="nb">max</span><span class="si">}</span><span class="s2"> bytes.&quot;</span>
<span class="p">)</span>
<span class="k">try</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">__prepare_execution_contexts</span><span class="p">()</span>
<span class="k">except</span><span class="p">:</span>
<span class="n">free_mem</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">mem_get_info</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">if</span> <span class="n">free_mem</span> <span class="o">&lt;</span> <span class="n">budget</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Failed to create context. Possibly out of memory: Memory budget is </span><span class="si">{</span><span class="n">budget</span><span class="si">}</span><span class="s2"> bytes but only </span><span class="si">{</span><span class="n">free_mem</span><span class="si">}</span><span class="s2"> bytes are available on the GPU.&quot;</span>
<span class="p">)</span>
<span class="k">raise</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_check_tensors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">tensors</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">):</span>
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="n">ptr</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">get_tensor_address</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="k">if</span> <span class="n">ptr</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Engine I/O tensor </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> is unbound&quot;</span><span class="p">)</span>
<span class="n">shp</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">get_tensor_shape</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
<span class="k">if</span> <span class="nb">any</span><span class="p">([</span><span class="n">s</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">shp</span><span class="p">]):</span> <span class="c1"># skip if shape is not available</span>
<span class="k">continue</span>
<span class="n">dt</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="n">tdt</span> <span class="o">=</span> <span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="n">dt</span><span class="p">)</span>
<span class="n">sz</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tdt</span><span class="p">)</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">shp</span><span class="p">)</span>
<span class="n">tensors</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">ptr</span><span class="p">,</span> <span class="n">ptr</span> <span class="o">+</span> <span class="n">sz</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">shp</span><span class="p">,</span> <span class="n">sz</span><span class="p">))</span>
<span class="n">tensors</span><span class="o">.</span><span class="n">sort</span><span class="p">()</span> <span class="c1"># sort by start address</span>
<span class="n">starts</span><span class="p">,</span> <span class="n">ends</span><span class="p">,</span> <span class="n">names</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">tensors</span><span class="p">)</span>
<span class="n">starts</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">starts</span><span class="p">)</span>
<span class="n">ends</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">ends</span><span class="p">)</span>
<span class="n">overalps</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nonzero</span><span class="p">((</span><span class="n">starts</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">&lt;</span> <span class="n">ends</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">int</span><span class="p">())</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
<span class="k">if</span> <span class="n">overalps</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="c1"># unsqueeze if there is a single value so it became scalar</span>
<span class="n">overalps</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">overalps</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="k">if</span> <span class="n">overalps</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">overalps</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">1</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="n">overalps</span><span class="p">):</span>
<span class="n">left_name</span> <span class="o">=</span> <span class="n">names</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">right_name</span> <span class="o">=</span> <span class="n">names</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span>
<span class="k">if</span> <span class="s2">&quot;key_value&quot;</span> <span class="ow">in</span> <span class="n">left_name</span> <span class="ow">and</span> <span class="s2">&quot;key_value&quot;</span> <span class="ow">in</span> <span class="n">right_name</span><span class="p">:</span> <span class="c1"># kv</span>
<span class="n">left_names</span> <span class="o">=</span> <span class="n">left_name</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;_&quot;</span><span class="p">)</span>
<span class="n">right_names</span> <span class="o">=</span> <span class="n">right_name</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;_&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">left_names</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">right_names</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span> <span class="c1"># same kv layer</span>
<span class="k">assert</span> <span class="p">(</span><span class="n">left_names</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;past&quot;</span> <span class="ow">and</span> <span class="n">right_names</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;present&quot;</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span>
<span class="n">left_names</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;present&quot;</span> <span class="ow">and</span> <span class="n">right_names</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;past&quot;</span><span class="p">),</span> \
<span class="sa">f</span><span class="s2">&quot;Overlap found between </span><span class="si">{</span><span class="n">tensors</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="s2"> and </span><span class="si">{</span><span class="n">tensors</span><span class="p">[</span><span class="n">i</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="k">continue</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;TENSOR BUFFER OVERLAP DETECTED: </span><span class="si">{</span><span class="n">tensors</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="s2"> and </span><span class="si">{</span><span class="n">tensors</span><span class="p">[</span><span class="n">i</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s2"> !!!&quot;</span>
<span class="p">)</span>
<span class="k">return</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_insert_step_to_profiler</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">profiler</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s2">&quot;Profiler is disable&quot;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">profiler</span><span class="o">.</span><span class="n">results</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="s2">&quot;step&quot;</span><span class="p">,</span> <span class="n">step</span><span class="p">))</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_is_profiling</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">profiler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_run</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">,</span>
<span class="n">stream</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">if</span> <span class="n">stream</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">cuda_stream</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">stream</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">):</span>
<span class="n">stream</span> <span class="o">=</span> <span class="n">stream</span><span class="o">.</span><span class="n">cuda_stream</span>
<span class="n">ok</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">execute_async_v3</span><span class="p">(</span><span class="n">stream</span><span class="p">)</span>
<span class="k">return</span> <span class="n">ok</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__del__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">try</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">address</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaFree</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">)</span>
<span class="k">except</span> <span class="ne">TypeError</span><span class="p">:</span>
<span class="k">pass</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">context_mem_size</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">device_memory_size_v2</span>
<div class="viewcode-block" id="ModelConfig">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.ModelConfig">[docs]</a>
<span class="nd">@dataclass</span>
<span class="k">class</span><span class="w"> </span><span class="nc">ModelConfig</span><span class="p">:</span>
<span class="n">max_batch_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">max_beam_width</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">vocab_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">num_kv_heads</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">gpt_attention_plugin</span><span class="p">:</span> <span class="nb">bool</span>
<span class="n">gemm_allreduce_plugin</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">remove_input_padding</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">model_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span>
<span class="n">kv_cache_type</span><span class="p">:</span> <span class="n">KVCacheType</span> <span class="o">=</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">CONTINUOUS</span>
<span class="n">cross_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">head_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">has_position_embedding</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">has_token_type_embedding</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">tokens_per_block</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">max_prompt_embedding_table_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">quant_mode</span><span class="p">:</span> <span class="n">QuantMode</span> <span class="o">=</span> <span class="n">QuantMode</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">gather_context_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">gather_generation_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span>
<span class="n">lora_plugin</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">lora_target_modules</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="nb">list</span><span class="p">)</span>
<span class="n">trtllm_modules_to_hf_modules</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">skip_cross_kv</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">num_medusa_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">max_medusa_tokens</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">paged_state</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">mamba_conv1d_plugin</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">conv_kernel</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">layer_types</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="nb">list</span><span class="p">)</span>
<span class="n">rnn_hidden_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">rnn_head_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">rnn_conv_dim_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">state_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">state_dtype</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span>
<span class="n">gpu_weights_percent</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span>
<span class="c1"># ReDrafter</span>
<span class="n">redrafter_num_beams</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">redrafter_draft_len_per_beam</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">num_kv_heads_per_layer</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">num_kv_heads_per_cross_attn_layer</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">skip_cross_attn_blocks</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="c1"># language adapter</span>
<span class="n">language_adapter_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LanguageAdapterConfig</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<div class="viewcode-block" id="ModelConfig.from_model_config_cpp">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.ModelConfig.from_model_config_cpp">[docs]</a>
<span class="nd">@classmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">from_model_config_cpp</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">model_config_cpp</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s1">&#39;ModelConfig&#39;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Create a partially initialized ModelConfig instance from a given ModelConfig CPP binding instance.</span>
<span class="sd"> Note that each of these classes have fields that don&#39;t exist in the other, so the created ModelConfigPython</span>
<span class="sd"> won&#39;t have all of its fields initialized.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span>
<span class="n">max_batch_size</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">,</span>
<span class="n">max_beam_width</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">max_beam_width</span><span class="p">,</span>
<span class="n">vocab_size</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span>
<span class="n">num_layers</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">num_layers</span><span class="p">(),</span>
<span class="n">num_heads</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span>
<span class="n">num_kv_heads</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">num_kv_heads</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">hidden_size</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">use_packed_input</span><span class="p">,</span>
<span class="n">kv_cache_type</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">kv_cache_type</span><span class="p">,</span>
<span class="n">cross_attention</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">use_cross_attention</span><span class="p">,</span>
<span class="n">head_size</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
<span class="n">max_prompt_embedding_table_size</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span>
<span class="n">max_prompt_embedding_table_size</span><span class="p">,</span>
<span class="n">quant_mode</span><span class="o">=</span><span class="n">QuantMode</span><span class="p">(</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">value</span><span class="p">),</span>
<span class="n">gather_context_logits</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">compute_context_logits</span><span class="p">,</span>
<span class="n">gather_generation_logits</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">compute_generation_logits</span><span class="p">,</span>
<span class="n">gpt_attention_plugin</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">binding_to_str_dtype</span><span class="p">(</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">data_type</span><span class="p">),</span>
<span class="n">num_kv_heads_per_layer</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">num_kv_heads_per_layer</span><span class="p">,</span>
<span class="n">tokens_per_block</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
<span class="n">lora_plugin</span><span class="o">=</span><span class="n">model_config_cpp</span><span class="o">.</span><span class="n">use_lora_plugin</span><span class="p">,</span>
<span class="n">layer_types</span><span class="o">=</span><span class="p">[</span>
<span class="n">binding_layer_type_to_str</span><span class="p">(</span><span class="n">lt</span><span class="p">)</span>
<span class="k">for</span> <span class="n">lt</span> <span class="ow">in</span> <span class="n">model_config_cpp</span><span class="o">.</span><span class="n">layer_types</span>
<span class="p">],</span>
<span class="p">)</span></div>
</div>
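<span class="c1"># Example (a minimal sketch; `engine_cfg` is a hypothetical ModelConfig C++ binding</span>
<span class="c1"># instance obtained elsewhere): fields absent from the binding keep their defaults.</span>
<span class="c1">#     py_cfg = ModelConfig.from_model_config_cpp(engine_cfg)</span>
<span class="c1">#     assert py_cfg.gpu_weights_percent == 1.0  # dataclass default, not copied over</span>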
<div class="viewcode-block" id="SamplingConfig">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.SamplingConfig">[docs]</a>
<span class="nd">@dataclass</span>
<span class="k">class</span><span class="w"> </span><span class="nc">SamplingConfig</span><span class="p">:</span>
<span class="n">end_id</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">pad_id</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">max_new_tokens</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span>
<span class="n">num_beams</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">num_return_sequences</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">max_attention_window_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">sink_token_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">stop_words_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
<span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">bad_words_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
<span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">temperature</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
<span class="n">top_k</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">top_p</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<span class="n">top_p_decay</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="c1"># float</span>
<span class="n">top_p_min</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="c1"># float</span>
<span class="n">top_p_reset_ids</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="c1"># int</span>
<span class="n">random_seed</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">length_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
<span class="n">early_stopping</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">repetition_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
<span class="n">min_length</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">presence_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<span class="n">frequency_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<span class="n">prompt_ignore_length</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">use_beam_hyps</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># None here means user didn&#39;t set it, and dynamicDecodeOp.cpp take optional value</span>
<span class="c1"># The real default value is set in dynamicDecodeOp.cpp when it&#39;s None</span>
<span class="n">beam_search_diversity_rate</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<span class="n">output_cum_log_probs</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">output_log_probs</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">no_repeat_ngram_size</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">min_p</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
<div class="viewcode-block" id="SamplingConfig.update">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.SamplingConfig.update">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">unused_kwargs</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">):</span>
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">unused_kwargs</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span>
<span class="k">return</span> <span class="n">unused_kwargs</span></div>
</div>
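<span class="c1"># Example (a minimal sketch; the end_id/pad_id values are hypothetical): update()</span>
<span class="c1"># applies recognized fields and returns the rest.</span>
<span class="c1">#     cfg = SamplingConfig(end_id=2, pad_id=0)</span>
<span class="c1">#     leftover = cfg.update(temperature=0.8, not_a_field=1)</span>
<span class="c1">#     # cfg.temperature == 0.8 and leftover == {&#39;not_a_field&#39;: 1}</span>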
<div class="viewcode-block" id="LogitsProcessor">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.LogitsProcessor">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">LogitsProcessor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Base class for all logit processors that can be applied during generation.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="si">}</span><span class="s2"> is an abstract class. Only classes inheriting this class can be called.&quot;</span>
<span class="p">)</span></div>
<div class="viewcode-block" id="LogitsProcessorList">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.LogitsProcessorList">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">LogitsProcessorList</span><span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="n">LogitsProcessor</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
<span class="k">for</span> <span class="n">processor</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">:</span>
<span class="n">scores</span> <span class="o">=</span> <span class="n">processor</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">scores</span><span class="p">)</span>
<span class="k">return</span> <span class="n">scores</span></div>
<div class="viewcode-block" id="StoppingCriteria">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.StoppingCriteria">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">StoppingCriteria</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Base class for all stopping criteria that can be applied during generation.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;StoppingCriteria needs to be subclassed&quot;</span><span class="p">)</span></div>
<div class="viewcode-block" id="StoppingCriteriaList">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.StoppingCriteriaList">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">StoppingCriteriaList</span><span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="n">StoppingCriteria</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">return</span> <span class="nb">any</span><span class="p">(</span><span class="n">criteria</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">scores</span><span class="p">)</span> <span class="k">for</span> <span class="n">criteria</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">)</span></div>
<span class="k">class</span><span class="w"> </span><span class="nc">RuntimeTensor</span><span class="p">:</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span>
<span class="c1"># shape is the one sent to TRT, the actual torch tensor can be larger than the shape</span>
<span class="c1"># this is useful when allocating a big KV cache tensor at the beginning and incremental seq length dim of TRT engine&#39;s input tensor</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_torch_tensor</span> <span class="o">=</span> <span class="kc">None</span>
<span class="c1"># Used when pointer specified</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_data_ptr</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_dtype</span> <span class="o">=</span> <span class="kc">None</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">from_pointer</span><span class="p">(</span><span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">pointer</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span>
<span class="n">str_dtype</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s1">&#39;RuntimeTensor&#39;</span><span class="p">:</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">RuntimeTensor</span><span class="p">()</span>
<span class="n">t</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="n">name</span>
<span class="n">t</span><span class="o">.</span><span class="n">_data_ptr</span> <span class="o">=</span> <span class="n">pointer</span>
<span class="n">t</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="n">shape</span>
<span class="n">t</span><span class="o">.</span><span class="n">_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_torch</span><span class="p">(</span><span class="n">str_dtype</span><span class="p">)</span>
<span class="k">return</span> <span class="n">t</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">from_torch</span><span class="p">(</span>
<span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">data</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">override_shape</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Iterable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s1">&#39;RuntimeTensor&#39;</span><span class="p">:</span>
<span class="k">assert</span> <span class="p">(</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)),</span> <span class="sa">f</span><span class="s2">&quot;data </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> is </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">data</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">RuntimeTensor</span><span class="p">()</span>
<span class="n">t</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="n">name</span>
<span class="c1"># need to hold the torch tensor for memory life time</span>
<span class="n">t</span><span class="o">.</span><span class="n">_torch_tensor</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">t</span><span class="o">.</span><span class="n">_dtype</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">_torch_tensor</span><span class="o">.</span><span class="n">dtype</span>
<span class="n">t</span><span class="o">.</span><span class="n">_data_ptr</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">_torch_tensor</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">()</span>
<span class="n">torch_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">size</span><span class="p">())</span>
<span class="k">if</span> <span class="n">override_shape</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">t</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="n">override_shape</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">override_shape</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span>
<span class="n">override_shape</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span>
<span class="k">assert</span> <span class="nb">all</span><span class="p">([</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">override_shape</span>
<span class="p">]),</span> <span class="sa">f</span><span class="s2">&quot;Expect all dimensions &gt;=0, got </span><span class="si">{</span><span class="n">override_shape</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="k">def</span><span class="w"> </span><span class="nf">volume_func</span><span class="p">(</span><span class="n">dims</span><span class="p">):</span>
<span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="o">*</span> <span class="n">y</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">volume_func</span><span class="p">(</span><span class="n">override_shape</span><span class="p">)</span> <span class="o">&lt;=</span> <span class="n">volume_func</span><span class="p">(</span><span class="n">torch_shape</span><span class="p">),</span> \
<span class="sa">f</span><span class="s2">&quot;Override the shape to be larger than the underlying torch Tensor, got </span><span class="si">{</span><span class="n">override_shape</span><span class="si">}</span><span class="s2">, torch tensor shape </span><span class="si">{</span><span class="n">torch_shape</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">t</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="n">torch_shape</span>
<span class="k">return</span> <span class="n">t</span>
<span class="k">def</span><span class="w"> </span><span class="nf">to_torch</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_torch_tensor</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s1">&#39;RuntimeTensor cannot be converted to torch tensor as constructed from pointer&#39;</span>
<span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_torch_tensor</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_shape</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">data</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_data_ptr</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">name</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_name</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dtype</span>
<div class="viewcode-block" id="GenerationSession">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">GenerationSession</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="n">_model_config</span><span class="p">:</span> <span class="n">ModelConfig</span>
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span>
<span class="n">runtime</span><span class="p">:</span> <span class="n">_Runtime</span>
<span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">buffer_allocated</span><span class="p">:</span> <span class="nb">bool</span>
<span class="n">debug_mode</span><span class="p">:</span> <span class="nb">bool</span>
<span class="n">quant_mode</span><span class="p">:</span> <span class="n">QuantMode</span>
<span class="n">cuda_graph_mode</span><span class="p">:</span> <span class="nb">bool</span>
<span class="n">dtype</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span>
<span class="n">debug_tensors_to_save</span><span class="p">:</span> <span class="kc">None</span>
<span class="n">num_draft_tokens</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">medusa_topks</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">medusa_paths</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">medusa_tree_ids</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">medusa_position_offsets</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">medusa_temperature</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">model_config</span><span class="p">:</span> <span class="n">ModelConfig</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">,</span>
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">stream</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">model_config</span><span class="p">,</span> <span class="n">ModelConfig</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span> <span class="o">=</span> <span class="n">model_config</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span> <span class="o">=</span> <span class="n">mapping</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span> <span class="o">=</span> <span class="n">_Runtime</span><span class="p">(</span><span class="n">engine_buffer</span><span class="p">,</span> <span class="n">mapping</span><span class="p">)</span>
<span class="k">if</span> <span class="n">DISABLE_TORCH_DEVICE_SET</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;cuda:</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_device</span><span class="p">()</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;cuda:</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">runtime_rank</span><span class="w"> </span><span class="o">%</span><span class="w"> </span><span class="n">mapping</span><span class="o">.</span><span class="n">gpus_per_node</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="c1"># dynamic_decoder currently use torch&#39;s current stream, so must let TRT enqueue use same stream here</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">stream</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span> <span class="o">=</span> <span class="n">debug_mode</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span> <span class="o">=</span> <span class="n">debug_tensors_to_save</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span> <span class="o">=</span> <span class="n">cuda_graph_mode</span>
<span class="c1"># Optional inputs for dynamic decoder</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_decay</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_min</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_reset_ids</span> <span class="o">=</span> <span class="kc">None</span>
<span class="c1"># TODO: in tensorrt_llm/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp it&#39;s T, can be float or half?</span>
<span class="bp">self</span><span class="o">.</span><span class="n">embedding_bias_opt</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span> <span class="o">=</span> <span class="kc">False</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span> <span class="o">=</span> <span class="n">pad_vocab_size</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">model_config</span><span class="o">.</span><span class="n">layer_types</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;attention&#39;</span><span class="p">]</span> <span class="o">*</span> <span class="n">model_config</span><span class="o">.</span><span class="n">num_layers</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">layer_types</span> <span class="o">=</span> <span class="n">model_config</span><span class="o">.</span><span class="n">layer_types</span>
<span class="n">layer_types</span> <span class="o">=</span> <span class="n">layer_types</span> <span class="o">*</span> <span class="p">(</span><span class="n">model_config</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">//</span>
<span class="nb">len</span><span class="p">(</span><span class="n">layer_types</span><span class="p">))</span>
<span class="n">layer_types</span> <span class="o">=</span> <span class="n">layer_types</span> <span class="o">+</span> <span class="n">layer_types</span><span class="p">[</span><span class="mi">0</span><span class="p">:(</span><span class="n">model_config</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">%</span>
<span class="nb">len</span><span class="p">(</span><span class="n">layer_types</span><span class="p">))]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span> <span class="o">=</span> <span class="n">layer_types</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span> <span class="o">=</span> \
<span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">]</span><span class="o">.</span><span class="n">count</span><span class="p">(</span><span class="s1">&#39;attention&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span> <span class="o">&gt;</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span> <span class="o">=</span> <span class="s1">&#39;recurrent&#39;</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">attn_to_general_idx</span> <span class="o">=</span> <span class="p">{}</span>
<span class="bp">self</span><span class="o">.</span><span class="n">general_to_attn_idx</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">attn_layer_idx</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;attention&#39;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">attn_to_general_idx</span><span class="p">[</span><span class="n">attn_layer_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">i</span>
<span class="bp">self</span><span class="o">.</span><span class="n">general_to_attn_idx</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">attn_layer_idx</span>
<span class="n">attn_layer_idx</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="c1"># Cyclic KV cache buffer names.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn_to_general_idx</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_buffer_names</span> <span class="o">=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">layer_idx</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn_to_general_idx</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
<span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="sa">f</span><span class="s1">&#39;1_present_key_value_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">attn_to_general_idx</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_buffer_names</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;The paged KV cache in Python runtime is experimental. For performance and correctness, please, use C++ runtime.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">classes</span><span class="o">.</span><span class="n">trtllm</span><span class="o">.</span><span class="n">NcclCommunicatorOp</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">world_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;logits&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">]:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;Logits dtype not supported by decoder. Falling back to float32. You may want to change the logits dtype to float16 in your model definition.&quot;</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dynamic_decoder</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">classes</span><span class="o">.</span><span class="n">trtllm</span><span class="o">.</span><span class="n">DynamicDecodeOp</span><span class="p">(</span>
<span class="n">model_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">,</span> <span class="n">model_config</span><span class="o">.</span><span class="n">max_beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
<span class="n">expected_tensor_names</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ipc_buffers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span> <span class="o">=</span> <span class="n">CustomAllReduceHelper</span><span class="o">.</span><span class="n">allocate_workspace</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="p">,</span>
<span class="n">CustomAllReduceHelper</span><span class="o">.</span><span class="n">max_workspace_size_auto</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gather_tree</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">tensorrt_llm</span><span class="o">.</span><span class="n">gather_tree</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;input_ids&#39;</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;hidden_states_input&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">model_config</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;hidden_states_output&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">has_position_embedding</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">(</span>
<span class="p">):</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;position_ids&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">has_token_type_embedding</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">(</span>
<span class="p">):</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;token_type_ids&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_kv_cache</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;cache_indirection&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">&#39;kv_cache_block_offsets&#39;</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">&#39;host_kv_cache_block_offsets&#39;</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">&#39;host_kv_cache_pool_pointers&#39;</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">&#39;host_kv_cache_pool_mapping&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">&#39;cross_kv_cache_block_offsets&#39;</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">&#39;host_cross_kv_cache_block_offsets&#39;</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">&#39;host_cross_kv_cache_pool_pointers&#39;</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">&#39;host_cross_kv_cache_pool_mapping&#39;</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">&#39;cross_attention_mask&#39;</span><span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">&#39;cross_attention_packed_mask&#39;</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># Refer to gpt_attention() inside functional.py</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_kv_cache</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;attention&#39;</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span><span class="p">:</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;attention&#39;</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;cross_past_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="s1">&#39;cross_attention_mask&#39;</span><span class="p">,</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">&#39;cross_attention_packed_mask&#39;</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="s1">&#39;cross_attention_mask&#39;</span><span class="p">,</span>
<span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;recurrent&#39;</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;conv_state_ptr_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;rnn_state_ptr_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;slot_mapping&#39;</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;recurrent&#39;</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;past_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;past_rnn_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;present_rnn_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_kv_cache</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="s1">&#39;sequence_length&#39;</span><span class="p">,</span> <span class="s1">&#39;host_past_key_value_lengths&#39;</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="s1">&#39;context_lengths&#39;</span><span class="p">,</span> <span class="s1">&#39;host_request_types&#39;</span><span class="p">,</span>
<span class="s1">&#39;host_sink_token_length&#39;</span><span class="p">,</span> <span class="s1">&#39;host_runtime_perf_knobs&#39;</span><span class="p">,</span>
<span class="s1">&#39;host_context_progress&#39;</span>
<span class="p">]</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">&#39;host_max_attention_window_sizes&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">expected_tensor_names</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s1">&#39;host_context_lengths&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;host_request_types&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">mamba_conv1d_plugin</span> <span class="ow">and</span> <span class="n">model_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">expected_tensor_names</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s1">&#39;host_context_lengths&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">max_prompt_embedding_table_size</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="s1">&#39;prompt_embedding_table&#39;</span><span class="p">,</span> <span class="s1">&#39;tasks&#39;</span><span class="p">,</span> <span class="s1">&#39;prompt_vocab_size&#39;</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="s1">&#39;encoder_output&#39;</span><span class="p">,</span>
<span class="s1">&#39;encoder_input_lengths&#39;</span><span class="p">,</span>
<span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">,</span>
<span class="s1">&#39;cross_kv_cache_gen&#39;</span><span class="p">,</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">skip_cross_attn_blocks</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;skip_cross_attn_blocks&#39;</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">skip_cross_kv</span> <span class="o">=</span> <span class="n">model_config</span><span class="o">.</span><span class="n">skip_cross_kv</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip_cross_kv</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;cross_kv_reuse&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;all_reduce_workspace&#39;</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="o">=</span> <span class="n">model_config</span><span class="o">.</span><span class="n">lora_target_modules</span>
<span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span> <span class="o">=</span> <span class="n">LoraManager</span><span class="o">.</span><span class="n">get_missing_qkv_modules</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span><span class="p">)</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">lora_plugin</span><span class="p">:</span>
<span class="k">for</span> <span class="n">lora_module</span> <span class="ow">in</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="o">+</span>
<span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span><span class="p">):</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_ranks_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_weights_pointers_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;host_encoder_input_lengths&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">num_medusa_heads</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
<span class="s1">&#39;spec_decoding_generation_lengths&#39;</span><span class="p">,</span>
<span class="s1">&#39;spec_decoding_position_offsets&#39;</span><span class="p">,</span> <span class="s1">&#39;spec_decoding_packed_mask&#39;</span><span class="p">,</span>
<span class="s1">&#39;spec_decoding_use&#39;</span><span class="p">,</span> <span class="s1">&#39;medusa_logits&#39;</span>
<span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="n">get_redrafter_tensor_names</span><span class="p">()</span>
<span class="c1"># language adapter</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">language_adapter_config</span><span class="p">:</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">&#39;language_adapter_routings&#39;</span><span class="p">]</span>
<span class="n">found_tensor_names</span> <span class="o">=</span> <span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">found_tensor_names</span><span class="p">:</span>
<span class="k">if</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">&quot;allreduce_ub_&quot;</span><span class="p">)</span> <span class="ow">or</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span>
<span class="s2">&quot;gemm_allreduce&quot;</span><span class="p">):</span>
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="n">name</span><span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span> <span class="ow">and</span> <span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">set</span><span class="p">(</span>
<span class="n">found_tensor_names</span><span class="p">):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;The following expected tensors are not found: </span><span class="si">{</span><span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">)</span><span class="o">.</span><span class="n">difference</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">))</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Those tensors in engine are not expected: </span><span class="si">{</span><span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">)</span><span class="o">.</span><span class="n">difference</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">))</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Expected tensor names: </span><span class="si">{</span><span class="n">expected_tensor_names</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Found tensor names: </span><span class="si">{</span><span class="n">found_tensor_names</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;Tensor names in engine are not the same as expected, to use this GenerationSession, &quot;</span>
<span class="s2">&quot;you need to use PretrainedModel.prepare_inputs to create TRT Network inputs.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span>
<span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">)</span> <span class="o">-</span> <span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">))</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Debug tensors found: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Debug tensors to save: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__del__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">try</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gemm_allreduce_plugin</span><span class="p">:</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">ipc_nvls_free</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span><span class="p">)</span>
<span class="k">except</span> <span class="ne">TypeError</span><span class="p">:</span>
<span class="k">pass</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">context_mem_size</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_mem_size</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">vocab_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">vocab_size</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">num_layers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> \
<span class="sa">f</span><span class="s2">&quot;num_layers </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_layers</span><span class="si">}</span><span class="s2"> must be a multiple of pipeline parallelism size </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">first_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_rank</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">last_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">num_heads</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_heads</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">hidden_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="c1"># For linear layer in attention block</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">hidden_size</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">use_gpt_attention_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">use_mamba_conv1d_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">mamba_conv1d_plugin</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">paged_kv_cache</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">==</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">PAGED</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">kv_cache_type</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">kv_cache_type</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">use_kv_cache</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">!=</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">DISABLED</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">tokens_per_block</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">tokens_per_block</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">remove_input_padding</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">remove_input_padding</span>
<div class="viewcode-block" id="GenerationSession.get_num_heads_kv">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.get_num_heads_kv">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">get_num_heads_kv</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer_idx</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="k">if</span> <span class="n">layer_idx</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_layer</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">layer_types</span><span class="p">:</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span>
<span class="n">layer_idx</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;attention&quot;</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;Layer </span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s2"> is not an attention layer&quot;</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_layer</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_layer</span><span class="p">[</span><span class="n">layer_idx</span><span class="p">]</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads</span></div>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">head_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">head_size</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">head_size</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">max_prompt_embedding_table_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">max_prompt_embedding_table_size</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">quant_mode</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">quant_mode</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">gather_context_logits</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">gather_context_logits</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">gather_generation_logits</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">gather_generation_logits</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="n">str_dtype_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">profiler</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">profiler</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">engine_inspector</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine_inspector</span>
<div class="viewcode-block" id="GenerationSession.cuda_stream_guard">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.cuda_stream_guard">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">cuda_stream_guard</span><span class="p">(</span><span class="n">func</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Sync external stream and set current stream to the one bound to the session. Reset on exit.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@wraps</span><span class="p">(</span><span class="n">func</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">wrapper</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">external_stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span>
<span class="k">if</span> <span class="n">external_stream</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">:</span>
<span class="n">external_stream</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">)</span>
<span class="n">ret</span> <span class="o">=</span> <span class="n">func</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="k">if</span> <span class="n">external_stream</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_stream</span><span class="p">(</span><span class="n">external_stream</span><span class="p">)</span>
<span class="k">return</span> <span class="n">ret</span>
<span class="k">return</span> <span class="n">wrapper</span></div>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">cross_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">cross_attention</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">has_position_embedding</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">has_position_embedding</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">has_token_type_embedding</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">has_token_type_embedding</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">use_lora_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">lora_plugin</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">use_gemm_allreduce_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="nb">bool</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">gemm_allreduce_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">is_medusa_mode</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span> <span class="o">&gt;</span> <span class="mi">0</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">is_redrafter_mode</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">redrafter_num_beams</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">redrafter_draft_len_per_beam</span> <span class="o">&gt;</span> <span class="mi">0</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">max_draft_tokens</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">redrafter_num_beams</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">redrafter_draft_len_per_beam</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">max_medusa_tokens</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">num_medusa_heads</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_medusa_heads</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">paged_state</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">paged_state</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">conv_kernel</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">conv_kernel</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">rnn_hidden_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">rnn_hidden_size</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">rnn_head_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">rnn_head_size</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">rnn_conv_dim_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">rnn_conv_dim_size</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">state_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">state_size</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">state_dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">state_dtype</span> <span class="o">==</span> <span class="s2">&quot;&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="n">str_dtype_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">return</span> <span class="n">str_dtype_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">state_dtype</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_capture_cuda_graph_and_instantiate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">,</span> <span class="n">stream</span><span class="p">,</span> <span class="n">step</span><span class="p">):</span>
<span class="n">instance_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="c1"># Create two cuda graph once.If cuda graph has already existed, skip it.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span>
<span class="c1"># capture cuda graph</span>
<span class="n">CUASSERT</span><span class="p">(</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaStreamBeginCapture</span><span class="p">(</span>
<span class="n">stream</span><span class="p">,</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaStreamCaptureMode</span><span class="o">.</span><span class="n">cudaStreamCaptureModeGlobal</span><span class="p">))</span>
<span class="n">context</span><span class="o">.</span><span class="n">execute_async_v3</span><span class="p">(</span><span class="n">stream</span><span class="p">)</span>
<span class="n">next_graph</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaStreamEndCapture</span><span class="p">(</span><span class="n">stream</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span>
<span class="n">instance_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">_update_cuda_graph_instance</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">],</span> <span class="n">next_graph</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphInstantiate</span><span class="p">(</span><span class="n">next_graph</span><span class="p">,</span> <span class="mi">0</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="c1"># Pre-upload cuda graph to stream</span>
<span class="n">CUASSERT</span><span class="p">(</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphUpload</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">],</span> <span class="n">stream</span><span class="p">))</span>
<span class="k">def</span><span class="w"> </span><span class="nf">__setup_decoder</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;Allocate buffers and setup the post-processing decoder kernel</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">scfg</span> <span class="o">=</span> <span class="n">sampling_config</span> <span class="c1"># just to make a shorter name, no other meaning</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.top_k.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.top_k.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.top_p.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.top_p.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.temperature.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.temperature.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">temperature</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">temperature</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.repetition_penalty.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.repetition_penalty.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">repetition_penalty</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span> <span class="o">==</span> <span class="mf">1.0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">repetition_penalty</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">repetition_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.length_penalty.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.length_penalty.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">length_penalty</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.early_stopping.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.early_stopping.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_early_stopping</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_early_stopping</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.presence_penalty.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.presence_penalty.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">==</span> <span class="mf">0.0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.frequency_penalty.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.frequency_penalty.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">frequency_penalty</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span> <span class="o">==</span> <span class="mf">0.0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">frequency_penalty</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">frequency_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">prompt_ignore_length</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">prompt_ignore_length</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.prompt_ignore_length.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">prompt_ignore_length</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">prompt_ignore_length</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.prompt_ignore_length.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">prompt_ignore_length</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">prompt_ignore_length</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">prompt_ignore_length</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">prompt_ignore_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">prompt_ignore_length</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.min_length.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.min_length.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min_length</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.beam_search_diversity_rate.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.beam_search_diversity_rate.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.random_seed.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int64&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.random_seed.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.no_repeat_ngram_size.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.no_repeat_ngram_size.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_p</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_p</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.min_p.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_p</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_p</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;scfg.min_p.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_p</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min_p</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_p</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_p</span> <span class="o">==</span> <span class="mf">1.0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min_p</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min_p</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">min_p</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dynamic_decoder</span><span class="o">.</span><span class="n">setup</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">temperature</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">prompt_ignore_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_early_stopping</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_decay</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_min</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_reset_ids</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min_p</span><span class="p">,</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">output_log_probs</span><span class="p">,</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_cum_log_probs</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;end_id cannot be none&quot;</span>
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">pad_id</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">&#39;pad_id cannot be none&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">max</span><span class="p">()</span>
<span class="c1"># setup output ids buffer</span>
<span class="k">if</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># input_ids only have one dimension, which means remove_padding is enabled</span>
<span class="n">split_ids_list</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">input_ids</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">host_context_lengths</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
<span class="n">padded_input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">to_padded_tensor</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">nested_tensor</span><span class="p">(</span><span class="n">split_ids_list</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">),</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">pad_id</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">padded_input_ids</span> <span class="o">=</span> <span class="n">input_ids</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">tiled_input_ids</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">padded_input_ids</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">)</span>
<span class="n">tiled_input_ids</span> <span class="o">=</span> <span class="n">tiled_input_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">)</span>
<span class="n">tiled_input_ids</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># TODO: delete?</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">(</span><span class="n">tiled_input_ids</span><span class="p">,</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">-</span> <span class="n">max_context_length</span><span class="p">),</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)),</span>
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">(</span><span class="n">padded_input_ids</span><span class="p">,</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">-</span> <span class="n">max_context_length</span><span class="p">),</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)),</span>
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># Note: we still allocate max_seq_length size of parent ids (not max_attention_window_size).</span>
<span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">redrafter_draft_len_per_beam</span> <span class="o">+</span> <span class="mi">1</span>
<span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s2">&quot;redrafter_inverted_temperature&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">reciprocal</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">temperature</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_temperature</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_logits</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;logits&#39;</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_cum_log_probs</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">),</span>
<span class="o">-</span><span class="mf">1e20</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_log_probs</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs_tiled</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">,</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs_tiled</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">finished</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">uint8</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_cba</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">(</span>
<span class="n">size</span><span class="o">=</span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">],</span>
<span class="n">fill_value</span><span class="o">=</span><span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_seq_len_cba</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs_cba</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores_cba</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs_cba</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_cba</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_seq_len_cba</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs_cba</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores_cba</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs_cba</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span> <span class="o">=</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cross_kv_reuse</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_tensor_dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="c1"># return torch dtype given tensor name for convenience</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
<span class="k">return</span> <span class="n">dtype</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_init_medusa</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">medusa_choices</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]):</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">tensorrt_llm.runtime.medusa_utils</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span><span class="n">_medusa_setup</span><span class="p">,</span>
<span class="n">expand_choices_if_needed</span><span class="p">)</span>
<span class="n">medusa_choices</span> <span class="o">=</span> <span class="n">expand_choices_if_needed</span><span class="p">(</span><span class="n">medusa_choices</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">medusa_choices</span><span class="p">)</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">&lt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_draft_tokens</span>
<span class="n">medusa_info</span> <span class="o">=</span> <span class="n">_medusa_setup</span><span class="p">(</span><span class="n">medusa_choices</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_topks</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_topks</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_mask</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_mask</span><span class="p">[</span><span class="mi">1</span><span class="p">:,</span> <span class="mi">1</span><span class="p">:]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">bool</span>
<span class="p">)</span> <span class="c1"># convert to bool, original mask includes true token as well</span>
<span class="c1"># Expand medusa position offsets to number of batch size in order to be compatible with the new Medusa.</span>
<span class="n">target_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_packed_mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">target_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span>
<span class="c1"># Note: spec_decoding_packed_mask has no paddings in the first dimension.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">spec_decoding_packed_mask</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_packed_mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span>
<span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">target_shape</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">target_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">spec_decoding_use</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_spec_decoding_use</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_paths</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_tree_ids</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_tree_ids</span>
<span class="c1"># Expand medusa position offsets to number of batch size in order to be compatible with the new Medusa.</span>
<span class="n">target_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span>
<span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_position_offsets</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">target_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span>
<span class="c1"># Note: medusa_position_offsets still keeps the paddings in order to get max_gen_input_length from the shape info.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">spec_decoding_position_offsets</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_position_offsets</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span>
<span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">target_shape</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="c1"># Fixed sequence lengths currently.</span>
<span class="c1"># Support variable sequence lengths later.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">spec_decoding_generation_lengths</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">))</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">medusa_fp_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">medusa_mask</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">medusa_fp_mask</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">logical_not</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">medusa_mask</span><span class="p">)]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_mask</span> <span class="o">=</span> <span class="n">medusa_fp_mask</span>
<span class="k">return</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_get_num_paged_blocks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_attention_window_size</span><span class="p">,</span>
<span class="n">sink_token_length</span><span class="p">):</span>
<span class="n">bubble_len</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="n">sink_token_length</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">bubble_len</span> <span class="o">+=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span> <span class="o">-</span>
<span class="n">sink_token_length</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">)</span>
<span class="n">max_blocks_per_seq</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span>
<span class="p">(</span><span class="n">max_attention_window_size</span> <span class="o">+</span> <span class="n">bubble_len</span><span class="p">)</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">)</span>
<span class="n">num_blocks</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_width</span> <span class="o">*</span> <span class="n">max_blocks_per_seq</span>
<span class="k">return</span> <span class="n">num_blocks</span><span class="p">,</span> <span class="n">max_blocks_per_seq</span>
<div class="viewcode-block" id="GenerationSession.setup">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.setup">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_new_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">max_attention_window_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">sink_token_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_max_input_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_manager</span><span class="p">:</span> <span class="n">LoraManager</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_uids</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">medusa_choices</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">multi_block_mode</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">enable_context_fmha_fp32_acc</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="c1"># Store these params related to buffer size to check against</span>
<span class="c1"># the input shape with the params given in decode()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_context_length</span> <span class="o">=</span> <span class="n">max_context_length</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">=</span> <span class="n">max_new_tokens</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">=</span> <span class="n">max_context_length</span> <span class="o">+</span> <span class="n">max_new_tokens</span>
<span class="k">if</span> <span class="n">medusa_choices</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_draft_tokens</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_width</span> <span class="o">=</span> <span class="n">beam_width</span>
<span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span> <span class="o">=</span> <span class="n">encoder_max_input_length</span>
<span class="bp">self</span><span class="o">.</span><span class="n">multi_block_mode</span> <span class="o">=</span> <span class="n">multi_block_mode</span>
<span class="bp">self</span><span class="o">.</span><span class="n">enable_context_fmha_fp32_acc</span> <span class="o">=</span> <span class="n">enable_context_fmha_fp32_acc</span>
<span class="k">if</span> <span class="n">max_attention_window_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
<span class="s2">&quot;The max_attention_window_size is not set, we will use max_seq_length by default.&quot;</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="k">if</span> <span class="n">max_attention_window_size</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;The value of max_attention_window_size should ideally not exceed max_seq_length. &quot;</span>
<span class="s2">&quot;Therefore, it has been adjusted to match the value of max_seq_length.&quot;</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">list</span><span class="p">)):</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
<span class="n">max_attention_window_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
<span class="n">attn_win_size_len</span> <span class="o">=</span> <span class="n">max_attention_window_size</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">num_total_attn_layers</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="o">.</span><span class="n">count</span><span class="p">(</span><span class="s1">&#39;attention&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">attn_win_size_len</span> <span class="o">&lt;</span> <span class="n">num_total_attn_layers</span><span class="p">:</span>
<span class="n">repeat_num</span> <span class="o">=</span> <span class="n">num_total_attn_layers</span> <span class="o">//</span> <span class="n">attn_win_size_len</span>
<span class="n">remain_num</span> <span class="o">=</span> <span class="n">num_total_attn_layers</span> <span class="o">%</span> <span class="n">attn_win_size_len</span>
<span class="n">warning_info</span> <span class="o">=</span> <span class="s2">&quot;The size of max_attention_window_size tensor/list is less than num_attn_layers, &quot;</span> \
<span class="o">+</span> <span class="s2">&quot;and it will be repeated to num_attn_layers. So the actual max_attention_window_size &quot;</span> \
<span class="o">+</span> <span class="sa">f</span><span class="s2">&quot;is </span><span class="si">{</span><span class="n">max_attention_window_size</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span><span class="si">}</span><span class="s2"> * </span><span class="si">{</span><span class="n">repeat_num</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">warning_info</span> <span class="o">+=</span> <span class="sa">f</span><span class="s2">&quot; + </span><span class="si">{</span><span class="n">max_attention_window_size</span><span class="o">.</span><span class="n">tolist</span><span class="p">()[</span><span class="mi">0</span><span class="p">:</span><span class="n">remain_num</span><span class="p">]</span><span class="si">}</span><span class="s2">. &quot;</span> <span class="k">if</span> <span class="n">remain_num</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">else</span> <span class="s2">&quot;. &quot;</span>
<span class="n">warning_info</span> <span class="o">+=</span> <span class="s2">&quot;Note that num_attn_layers is the number of total attention layers.&quot;</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="n">warning_info</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">attn_win_size_len</span> <span class="o">&gt;</span> <span class="n">num_total_attn_layers</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="s2">&quot;The size of max_attention_window_size tensor/list is larger than num_attn_layers! &quot;</span>
<span class="s2">&quot;Note that num_attn_layers is the number of total attention layers.&quot;</span>
<span class="p">)</span>
<span class="k">assert</span> <span class="kc">False</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;The value of max_attention_window_size should ideally not exceed max_seq_length. &quot;</span>
<span class="s2">&quot;Therefore, it has been adjusted to match the value of max_seq_length.&quot;</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span>
<span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">minimum</span><span class="p">(</span>
<span class="n">max_attention_window_size</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">torch</span><span class="o">.</span><span class="n">IntTensor</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">]</span> <span class="o">*</span> <span class="n">attn_win_size_len</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span><span class="p">[</span>
<span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">max_attention_window_size</span><span class="p">[</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">]</span><span class="o">.</span><span class="n">count</span><span class="p">(</span><span class="s1">&#39;attention&#39;</span><span class="p">)</span>
<span class="o">+</span> <span class="n">i</span><span class="p">)</span> <span class="o">%</span> <span class="n">attn_win_size_len</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s2">&quot;invalid max_attention_window_size!&quot;</span>
<span class="k">if</span> <span class="n">sink_token_length</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">=</span> <span class="mi">0</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_sink_token_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">sink_token_length</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">=</span> <span class="n">sink_token_length</span>
<span class="bp">self</span><span class="o">.</span><span class="n">host_sink_token_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
<span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s2">&quot;invalid sink_token_length!&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span> <span class="o">=</span> <span class="n">lora_manager</span>
<span class="k">if</span> <span class="n">medusa_choices</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_init_medusa</span><span class="p">(</span><span class="n">medusa_choices</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="n">init_allocate_redrafter_tensors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="k">else</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;logits&#39;</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="k">else</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;logits&#39;</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">medusa_logits_shape</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">medusa_logits_shape</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span> <span class="o">*</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;medusa_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">medusa_logits_shape</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="k">else</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;medusa_logits&#39;</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="k">else</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;logits&#39;</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="c1"># use shape info to pass max length info in remove padding mode</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="p">(</span><span class="n">encoder_max_input_length</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">),</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_kv_cache_quant</span><span class="p">():</span>
<span class="c1"># Since torch does not support fp8 now, using int8 here.</span>
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">int8</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="n">first_atten_layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">]</span><span class="o">.</span><span class="n">index</span><span class="p">(</span>
<span class="s1">&#39;attention&#39;</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span>
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">first_atten_layer</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_kv_cache</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="n">num_blocks</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_num_paged_blocks</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_memory_pool_allocator</span> <span class="o">=</span> <span class="n">MemoryPoolsAllocator</span><span class="p">(</span>
<span class="n">num_blocks</span><span class="o">=</span><span class="n">num_blocks</span><span class="p">,</span>
<span class="n">tokens_per_block</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
<span class="n">head_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_layer</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">num_kv_heads_per_layer</span> <span class="o">=</span> <span class="n">MemoryPoolsAllocator</span><span class="o">.</span><span class="n">prepare_num_kv_heads_per_layer</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">num_kv_heads_per_layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_layer</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_memory_pool_allocator</span><span class="o">.</span><span class="n">allocate</span><span class="p">(</span><span class="n">kv_cache_type</span><span class="p">,</span>
<span class="n">num_kv_heads_per_layer</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span> <span class="c1"># As for now we enable cross paged kv and self paged kv to share the same tokens_per_block</span>
<span class="n">cross_num_blocks</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_num_paged_blocks</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span><span class="p">,</span> <span class="n">sink_token_length</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">num_kv_heads_per_layer</span> <span class="o">=</span> <span class="n">MemoryPoolsAllocator</span><span class="o">.</span><span class="n">prepare_num_kv_heads_per_layer</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_cross_memory_pool_allocator</span> <span class="o">=</span> <span class="n">MemoryPoolsAllocator</span><span class="p">(</span>
<span class="n">num_blocks</span><span class="o">=</span><span class="n">cross_num_blocks</span><span class="p">,</span>
<span class="n">tokens_per_block</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
<span class="n">head_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_cross_attn_layer</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">num_kv_heads_per_cross_attn_layer</span> <span class="o">=</span> <span class="n">MemoryPoolsAllocator</span><span class="o">.</span><span class="n">prepare_num_kv_heads_per_layer</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">num_kv_heads_per_cross_attn_layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_cross_attn_layer</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_cross_memory_pool_allocator</span><span class="o">.</span><span class="n">allocate</span><span class="p">(</span>
<span class="n">kv_cache_type</span><span class="p">,</span> <span class="n">num_kv_heads_per_cross_attn_layer</span><span class="p">)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;attention&#39;</span><span class="p">:</span>
<span class="n">cache_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="mi">2</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(</span><span class="n">i</span><span class="p">),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">cache_shape</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">kv_cache_type</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">cross_cache_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="mi">2</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;attention&#39;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">cross_cache_shape</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">kv_cache_type</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># Without plugin, we need extra kv cache buffers.</span>
<span class="c1"># Because we don&#39;t support inplace update, so we need separate buffer for inputs and outputs.</span>
<span class="c1"># We can do reuse between different layers&#39; inputs and outputs, i.e. current layer&#39;s output can</span>
<span class="c1"># reuse previous layer&#39;s input memory. But this need one extra buffer as the guard.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span> <span class="c1"># Not applicable to cross KV buffers as it&#39;s constant</span>
<span class="n">i</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn_to_general_idx</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">trt_dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">trt_dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span><span class="p">:</span>
<span class="c1"># PyTorch doesn&#39;t support fp8 datatype, use int8 instead of it because int8 datatype size is same with fp8.</span>
<span class="c1"># TODO: Remove this section when PyTorch support fp8 datatype</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">int8</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;1_present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">cache_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_mamba_conv1d_plugin</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">conv_kernel</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rnn_conv_dim_size</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rnn_conv_dim_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">conv_kernel</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">rnn_head_size</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">rnn_state_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rnn_hidden_size</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">rnn_head_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">state_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rnn_head_size</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">rnn_state_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">state_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rnn_hidden_size</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;recurrent&#39;</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">conv_state_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;1_present_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">conv_state_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_rnn_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">rnn_state_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">state_dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">conv_state_ptr</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
<span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">()],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span>
<span class="n">rnn_state_ptr</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
<span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_rnn_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">()],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;conv_state_ptr_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">conv_state_ptr</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;rnn_state_ptr_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">rnn_state_ptr</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_lora_plugin</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">lora_uids</span> <span class="o">=</span> <span class="n">lora_uids</span> <span class="ow">or</span> <span class="p">[</span><span class="s2">&quot;-1&quot;</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span><span class="o">.</span><span class="n">input_buffers</span><span class="p">(</span>
<span class="n">lora_uids</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span>
<span class="p">))</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gemm_allreduce_plugin</span><span class="p">:</span>
<span class="n">max_num_tokens</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="n">batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span>
<span class="n">M</span> <span class="o">=</span> <span class="n">max_num_tokens</span>
<span class="n">N</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_size</span> <span class="o">=</span> <span class="n">M</span> <span class="o">*</span> <span class="n">N</span>
<span class="n">itemsize</span> <span class="o">=</span> <span class="n">str_dtype_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span><span class="o">.</span><span class="n">itemsize</span>
<span class="n">alloc_bytes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_size</span> <span class="o">*</span> <span class="n">itemsize</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span> <span class="o">=</span> <span class="n">ipc_nvls_allocate</span><span class="p">(</span>
<span class="n">alloc_bytes</span><span class="p">,</span> <span class="nb">set</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_group</span><span class="p">))</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Allocated NVLS IPC memory: </span><span class="si">{</span><span class="n">alloc_bytes</span><span class="si">}</span><span class="s1"> bytes&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="s1">&#39;spec_decoding_packed_mask&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spec_decoding_packed_mask</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="s1">&#39;spec_decoding_position_offsets&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spec_decoding_position_offsets</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="s1">&#39;spec_decoding_generation_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spec_decoding_generation_lengths</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;spec_decoding_use&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spec_decoding_use</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span></div>
<span class="k">def</span><span class="w"> </span><span class="nf">_allocate_empty_kv_cache_pools</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">kv_cache_type</span><span class="p">,</span> <span class="n">num_blocks</span><span class="p">):</span>
<span class="c1"># Layers are homogeneous, use old kv cache shape</span>
<span class="n">unique_cache_pools</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_layer</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">cache_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">num_blocks</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">,</span>
<span class="mi">2</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">unique_cache_pools</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">cache_shape</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">kv_cache_type</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">))</span>
<span class="c1"># Layers are not homogeneous, use new kv cache shape</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">kv_heads_unique_counter</span> <span class="o">=</span> <span class="n">Counter</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_layer</span><span class="p">)</span>
<span class="k">for</span> <span class="n">kv_head</span><span class="p">,</span> <span class="n">num_layers</span> <span class="ow">in</span> <span class="n">kv_heads_unique_counter</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">cache_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">num_blocks</span><span class="p">,</span>
<span class="n">num_layers</span><span class="p">,</span>
<span class="mi">2</span><span class="p">,</span>
<span class="n">kv_head</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">unique_cache_pools</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">cache_shape</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">kv_cache_type</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">))</span>
<span class="k">return</span> <span class="n">unique_cache_pools</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_get_context_shape_buffer</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">position_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cross_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">host_cross_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">hidden_states_input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">host_runtime_perf_knobs</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">host_context_progress</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">skip_cross_attn_blocks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">language_adapter_routings</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">RuntimeTensor</span><span class="p">]:</span>
<span class="n">tensors</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">def</span><span class="w"> </span><span class="nf">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="k">return</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">add_tensor_from_pointer</span><span class="p">(</span><span class="n">pointer</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">str_dtype</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="n">name</span><span class="p">:</span>
<span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_pointer</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">pointer</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">str_dtype</span><span class="p">)</span>
<span class="p">})</span>
<span class="k">def</span><span class="w"> </span><span class="nf">add_tensor</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="n">name</span><span class="p">:</span> <span class="n">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">)})</span>
<span class="k">def</span><span class="w"> </span><span class="nf">add_tensor_with_shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
<span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">override_shape</span><span class="o">=</span><span class="n">shape</span><span class="p">)})</span>
<span class="k">def</span><span class="w"> </span><span class="nf">add_tensor_with_bs</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">bs</span><span class="p">):</span>
<span class="c1"># this assumes dim0 to be bs and only overrides dim0 with given bs</span>
<span class="n">shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">bs</span>
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
<span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">override_shape</span><span class="o">=</span><span class="n">shape</span><span class="p">)})</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="s1">&#39;context_lengths&#39;</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">host_runtime_perf_knobs</span> <span class="o">!=</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;gpt_attention_plugin needs to set host_runtime_perf_knobs&quot;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_runtime_perf_knobs</span><span class="p">,</span> <span class="s1">&#39;host_runtime_perf_knobs&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_progress</span><span class="p">,</span> <span class="s1">&#39;host_context_progress&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cache_indirection</span><span class="p">,</span> <span class="s1">&#39;cache_indirection&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="s1">&#39;position_ids&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="c1"># in context phase, need to generate cross kv cache, set to True</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
<span class="s1">&#39;cross_kv_cache_gen&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">skip_cross_attn_blocks</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">skip_cross_attn_blocks</span><span class="p">,</span> <span class="s1">&#39;skip_cross_attn_blocks&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip_cross_kv</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_kv_reuse</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># see Attention&#39;s self.qkv output dim</span>
<span class="n">cross_kv_out_dim</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(</span>
<span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span>
<span class="n">cross_kv_shape</span> <span class="o">=</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span>
<span class="n">cross_kv_out_dim</span><span class="p">,</span> <span class="p">)</span>
<span class="n">cross_kv_reuse</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">cross_kv_shape</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">encoder_output</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">encoder_output</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cross_kv_reuse</span> <span class="o">=</span> <span class="n">cross_kv_reuse</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cross_kv_reuse</span><span class="p">,</span> <span class="s1">&#39;cross_kv_reuse&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_output</span><span class="p">,</span> <span class="s1">&#39;encoder_output&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="s1">&#39;encoder_input_lengths&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">language_adapter_routings</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">language_adapter_routings</span><span class="p">,</span>
<span class="s1">&#39;language_adapter_routings&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">],</span>
<span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_mask</span><span class="p">,</span> <span class="s1">&#39;cross_attention_mask&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="n">cross_attention_mask</span> <span class="o">!=</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># cross-attention packed mask (used by fmha).</span>
<span class="n">cross_attention_packed_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">tensorrt_llm</span><span class="o">.</span><span class="n">pack_fmha_mask_by_input</span><span class="p">(</span>
<span class="n">cross_attention_mask</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_mask</span><span class="p">,</span> <span class="s1">&#39;cross_attention_mask&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_packed_mask</span><span class="p">,</span>
<span class="s1">&#39;cross_attention_packed_mask&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># create a full 1 cross_attention_mask because it is necessary</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">cross_attention_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
<span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">.</span><span class="n">prod</span><span class="p">(),</span>
<span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span><span class="p">)[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">prod</span><span class="p">()),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_mask</span><span class="p">,</span> <span class="s2">&quot;cross_attention_mask&quot;</span><span class="p">)</span>
<span class="n">cross_attention_packed_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">tensorrt_llm</span><span class="o">.</span><span class="n">pack_fmha_mask_by_input</span><span class="p">(</span>
<span class="n">cross_attention_mask</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_packed_mask</span><span class="p">,</span>
<span class="s2">&quot;cross_attention_packed_mask&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">hidden_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
<span class="k">if</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
<span class="n">hidden_states_input</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span>
<span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">hidden_size</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">hidden_states_input</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span>
<span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">hidden_size</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="n">set_redrafter_ctx_tensors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">add_tensor</span><span class="p">,</span> <span class="n">add_tensor_with_bs</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">],</span> <span class="s1">&#39;logits&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;medusa_logits&#39;</span><span class="p">],</span> <span class="s1">&#39;medusa_logits&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="s1">&#39;last_token_ids&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">hidden_states_input</span><span class="p">,</span> <span class="s1">&#39;hidden_states_output&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="s1">&#39;input_ids&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">hidden_states_input</span><span class="p">,</span> <span class="s1">&#39;hidden_states_input&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">prompt_embedding_table</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="s1">&#39;prompt_embedding_table&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">tasks_generation</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">([</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">context_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">()],</span>
<span class="n">tasks</span><span class="p">[</span><span class="n">b</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">context_lengths</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
<span class="p">])</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">tasks_generation</span> <span class="o">=</span> <span class="n">tasks</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">tasks_generation</span><span class="p">,</span> <span class="s1">&#39;tasks&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="s1">&#39;prompt_vocab_size&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="n">buffer</span> <span class="o">=</span> <span class="n">kv_cache_block_offsets</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">kv_cache_block_offsets</span><span class="o">.</span><span class="n">shape</span>
<span class="n">shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="o">*</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">:]]</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">buffer</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;kv_cache_block_offsets&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">host_kv_cache_block_offsets</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;host_kv_cache_block_offsets&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>
<span class="n">pool_pointers</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;host_kv_cache_pool_pointers&#39;</span>
<span class="n">pool_mapping</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;host_kv_cache_pool_mapping&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">pool_pointers</span><span class="p">],</span> <span class="n">pool_pointers</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">pool_mapping</span><span class="p">],</span> <span class="n">pool_mapping</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">cross_buffer</span> <span class="o">=</span> <span class="n">cross_kv_cache_block_offsets</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">cross_shape</span> <span class="o">=</span> <span class="n">cross_kv_cache_block_offsets</span><span class="o">.</span><span class="n">shape</span>
<span class="n">cross_shape</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">cross_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">cross_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">cross_shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span>
<span class="o">*</span><span class="n">cross_shape</span><span class="p">[</span><span class="mi">3</span><span class="p">:]</span>
<span class="p">]</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">cross_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;cross_kv_cache_block_offsets&#39;</span><span class="p">,</span>
<span class="n">cross_shape</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">host_cross_kv_cache_block_offsets</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;host_cross_kv_cache_block_offsets&#39;</span><span class="p">,</span>
<span class="n">cross_shape</span><span class="p">)</span>
<span class="n">cross_pool_pointers</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;host_cross_kv_cache_pool_pointers&#39;</span>
<span class="n">cross_pool_mapping</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;host_cross_kv_cache_pool_mapping&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">cross_pool_pointers</span><span class="p">],</span>
<span class="n">cross_pool_pointers</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">cross_pool_mapping</span><span class="p">],</span> <span class="n">cross_pool_mapping</span><span class="p">)</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_kv_cache</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span>
<span class="n">idx</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;attention&#39;</span><span class="p">:</span>
<span class="n">kv_cache_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">general_to_attn_idx</span><span class="p">[</span><span class="n">idx</span><span class="p">]),</span> <span class="mi">0</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
<span class="c1"># for empty tensor, TRT does not really use the tensor data, so any dtype is fine</span>
<span class="n">kv_cache_buffer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">kv_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span>
<span class="n">kv_cache_shape</span><span class="p">)</span>
<span class="n">present</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">present</span><span class="p">],</span> <span class="n">present</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">cross_kv_cache_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span> <span class="mi">0</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
<span class="c1"># for empty tensor, TRT does not really use the tensor data, so any dtype is fine</span>
<span class="n">cross_kv_cache_buffer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">cross_kv_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;cross_past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span>
<span class="n">cross_kv_cache_shape</span><span class="p">)</span>
<span class="n">cross_present</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">cross_present</span><span class="p">],</span> <span class="n">cross_present</span><span class="p">)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;attention&#39;</span><span class="p">:</span>
<span class="n">key_value_cache</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="c1"># when plugin is used, past_ket_value tensor does not need to be empty tensor</span>
<span class="c1"># because plugin does not care, and does not use this shape.</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">key_value_cache</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">key_value_cache</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">cross_cache_buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;cross_past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">!=</span> <span class="s1">&#39;recurrent&#39;</span><span class="p">:</span>
<span class="k">continue</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;conv_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;conv_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;rnn_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;rnn_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># conv state</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_mamba_conv1d_plugin</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_kernel</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rnn_conv_dim_size</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">rnn_conv_dim_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">conv_kernel</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">conv_state</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">conv_state_shape</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">conv_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;past_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">present</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">present</span><span class="p">],</span> <span class="n">present</span><span class="p">)</span>
<span class="c1"># rnn state</span>
<span class="n">rnn_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_rnn_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">rnn_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;past_rnn_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">rnn_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;present_rnn_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
<span class="n">slot_mapping</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">slot_mapping</span><span class="p">,</span> <span class="s1">&#39;slot_mapping&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="c1"># context request</span>
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="n">device_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span>
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">device_request_types</span><span class="p">,</span> <span class="s1">&#39;device_request_types&#39;</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span>
<span class="s1">&#39;sequence_length&#39;</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">))</span>
<span class="c1"># field 0: past_key_value_length, field 1: is_context (deprecated). changed to [0], otherwise affects batch padded input mode</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span>
<span class="s1">&#39;host_past_key_value_lengths&#39;</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">))</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_sink_token_length</span><span class="p">,</span>
<span class="s1">&#39;host_sink_token_length&#39;</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">))</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_request_types</span><span class="p">,</span> <span class="s1">&#39;host_request_types&#39;</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;host_max_attention_window_sizes&#39;</span><span class="p">,</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">,</span> <span class="p">))</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="p">,</span> <span class="s1">&#39;host_context_lengths&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_request_types</span><span class="p">,</span> <span class="s1">&#39;host_request_types&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="p">,</span> <span class="s1">&#39;host_context_lengths&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">attention_mask</span><span class="p">,</span> <span class="s1">&#39;attention_mask&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span><span class="p">,</span> <span class="s1">&#39;all_reduce_workspace&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gemm_allreduce_plugin</span><span class="p">:</span>
<span class="n">found_tensor_names</span> <span class="o">=</span> <span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">found_tensor_names</span><span class="p">:</span>
<span class="k">if</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">&quot;gemm_allreduce_uc_out&quot;</span><span class="p">):</span>
<span class="n">add_tensor_from_pointer</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span><span class="o">.</span><span class="n">uc_ptr</span><span class="p">,</span>
<span class="n">name</span><span class="p">,</span>
<span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_size</span><span class="p">),</span>
<span class="n">str_dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span>
<span class="k">if</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">&quot;gemm_allreduce_mc_out&quot;</span><span class="p">):</span>
<span class="n">add_tensor_from_pointer</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span><span class="o">.</span><span class="n">mc_ptr</span><span class="p">,</span>
<span class="n">name</span><span class="p">,</span>
<span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_size</span><span class="p">),</span>
<span class="n">str_dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span>
<span class="k">if</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">&quot;gemm_allreduce_ipc_out&quot;</span><span class="p">):</span>
<span class="n">add_tensor_from_pointer</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span><span class="o">.</span><span class="n">get_ipc_ptrs</span><span class="p">(),</span>
<span class="n">name</span><span class="p">,</span>
<span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_size</span><span class="p">),</span>
<span class="n">str_dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_lora_plugin</span><span class="p">:</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span>
<span class="k">for</span> <span class="n">lora_module</span> <span class="ow">in</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="o">+</span>
<span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span><span class="p">):</span>
<span class="n">layer_idx</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span>
<span class="n">lora_ranks</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_ranks_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">lora_ranks</span><span class="p">],</span> <span class="n">lora_ranks</span><span class="p">)</span>
<span class="n">lora_weights</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">lora_weights</span><span class="p">],</span> <span class="n">lora_weights</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_input_lengths</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;cpu&#39;</span><span class="p">),</span>
<span class="s1">&#39;host_encoder_input_lengths&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="c1"># Medusa mask and position offsets are fixed for the whole session.</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;spec_decoding_packed_mask&#39;</span><span class="p">],</span>
<span class="s1">&#39;spec_decoding_packed_mask&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;spec_decoding_position_offsets&#39;</span><span class="p">],</span>
<span class="s1">&#39;spec_decoding_position_offsets&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;spec_decoding_generation_lengths&#39;</span><span class="p">],</span>
<span class="s1">&#39;spec_decoding_generation_lengths&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;spec_decoding_use&#39;</span><span class="p">],</span> <span class="s1">&#39;spec_decoding_use&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">tensors</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_get_next_step_shape_buffer</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">position_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cross_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">host_cross_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">hidden_states_input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">host_runtime_perf_knobs</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">host_context_progress</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">skip_cross_attn_blocks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">language_adapter_routings</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">&quot;_get_next_step_shape_buffer&quot;</span><span class="p">)</span>
<span class="n">tensors</span> <span class="o">=</span> <span class="p">{}</span> <span class="c1"># Dict[str, RuntimeTensor]</span>
<span class="k">def</span><span class="w"> </span><span class="nf">add_tensor_from_pointer</span><span class="p">(</span><span class="n">pointer</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">str_dtype</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
<span class="n">name</span><span class="p">:</span>
<span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_pointer</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">pointer</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">str_dtype</span><span class="p">)</span>
<span class="p">})</span>
<span class="k">def</span><span class="w"> </span><span class="nf">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="k">return</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">add_tensor</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="n">name</span><span class="p">:</span> <span class="n">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">)})</span>
<span class="k">def</span><span class="w"> </span><span class="nf">add_tensor_with_shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
<span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">override_shape</span><span class="o">=</span><span class="n">shape</span><span class="p">)})</span>
<span class="n">context_lengths_local</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="n">host_context_lengths_local</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">context_lengths_local</span><span class="p">,</span> <span class="s1">&#39;context_lengths&#39;</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">host_runtime_perf_knobs</span> <span class="o">!=</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;gpt_attention_plugin needs to set host_runtime_perf_knobs&quot;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_runtime_perf_knobs</span><span class="p">,</span> <span class="s1">&#39;host_runtime_perf_knobs&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_progress</span><span class="p">,</span> <span class="s1">&#39;host_context_progress&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cache_indirection</span><span class="p">,</span> <span class="s1">&#39;cache_indirection&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="s1">&#39;position_ids&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">hidden_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
<span class="n">shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="n">hidden_size</span><span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span> <span class="k">else</span> <span class="p">(</span>
<span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span>
<span class="n">hidden_states_input</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span><span class="o">*</span><span class="n">shape</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">],</span> <span class="s1">&#39;logits&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;medusa_logits&#39;</span><span class="p">],</span> <span class="s1">&#39;medusa_logits&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="s1">&#39;last_token_ids&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">hidden_states_input</span><span class="p">,</span> <span class="s1">&#39;hidden_states_output&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="n">input_ids_shape</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_total_gen_token</span><span class="p">,</span> <span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">input_ids_shape</span> <span class="o">=</span> <span class="p">(</span>
<span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
<span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span> <span class="k">else</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;flat_tokens&#39;</span><span class="p">],</span> <span class="s1">&#39;input_ids&#39;</span><span class="p">,</span>
<span class="n">input_ids_shape</span><span class="p">)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span><span class="p">,</span> <span class="s1">&#39;input_ids&#39;</span><span class="p">,</span>
<span class="n">input_ids_shape</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="s1">&#39;input_ids&#39;</span><span class="p">,</span>
<span class="n">input_ids_shape</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">hidden_states_input</span><span class="p">,</span> <span class="s1">&#39;hidden_states_input&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="c1"># disable (or minimize) cross qkv computation at generation phase</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip_cross_kv</span><span class="p">:</span>
<span class="c1"># disable</span>
<span class="n">encoder_output_shape</span> <span class="o">=</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cross_kv_reuse</span><span class="p">,</span> <span class="s1">&#39;cross_kv_reuse&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># minimize</span>
<span class="c1"># use TensorRT Empty Tensor to skip redundant computation</span>
<span class="c1"># 0 for generation phase, &gt;0 for context phase</span>
<span class="n">encoder_output_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">encoder_output_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">encoder_output_shape</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]]</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># OOTB path doesn&#39;t have kv cache for now, so this encoder_output is</span>
<span class="c1"># a must-have input. We just use the encoder_output</span>
<span class="n">encoder_output_shape</span> <span class="o">=</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span>
<span class="c1"># in generation phase, cross kv cache is already filled during context phase, set to False</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
<span class="s1">&#39;cross_kv_cache_gen&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">skip_cross_attn_blocks</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">skip_cross_attn_blocks</span><span class="p">,</span> <span class="s1">&#39;skip_cross_attn_blocks&#39;</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">encoder_output</span><span class="p">,</span> <span class="s1">&#39;encoder_output&#39;</span><span class="p">,</span>
<span class="n">encoder_output_shape</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="s1">&#39;encoder_input_lengths&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">language_adapter_routings</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">language_adapter_routings</span><span class="p">,</span>
<span class="s1">&#39;language_adapter_routings&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">],</span>
<span class="s1">&#39;encoder_max_input_length&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_mask</span><span class="p">,</span> <span class="s1">&#39;cross_attention_mask&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="n">cross_attention_mask</span> <span class="o">!=</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">cross_attention_mask</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span>
<span class="n">cross_attention_mask</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="c1"># Empty packed mask is passed in the generation phase as it is not used.</span>
<span class="n">cross_attention_packed_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span>
<span class="p">(</span><span class="n">cross_attention_mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="mi">31</span><span class="p">)</span> <span class="o">//</span> <span class="mi">32</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_mask</span><span class="p">,</span> <span class="s1">&#39;cross_attention_mask&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_packed_mask</span><span class="p">,</span>
<span class="s1">&#39;cross_attention_packed_mask&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># create a full 1 cross_attention_mask because it is necessary in generation phase</span>
<span class="n">add_tensor</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span>
<span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span><span class="p">)[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">prod</span><span class="p">()),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="s2">&quot;cross_attention_mask&quot;</span><span class="p">)</span>
<span class="c1"># Empty packed mask is passed in the generation phase as it is not used.</span>
<span class="n">add_tensor</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
<span class="s2">&quot;cross_attention_packed_mask&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">kv_cache_block_offsets</span><span class="o">.</span><span class="n">shape</span>
<span class="n">shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="o">*</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">:]]</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">kv_cache_block_offsets</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;kv_cache_block_offsets&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">host_kv_cache_block_offsets</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;host_kv_cache_block_offsets&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>
<span class="n">pool_pointers</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;host_kv_cache_pool_pointers&#39;</span>
<span class="n">pool_mapping</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;host_kv_cache_pool_mapping&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">pool_pointers</span><span class="p">],</span> <span class="n">pool_pointers</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">pool_mapping</span><span class="p">],</span> <span class="n">pool_mapping</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">cross_shape</span> <span class="o">=</span> <span class="n">cross_kv_cache_block_offsets</span><span class="o">.</span><span class="n">shape</span>
<span class="n">cross_shape</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">cross_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">cross_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">cross_shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span>
<span class="o">*</span><span class="n">cross_shape</span><span class="p">[</span><span class="mi">3</span><span class="p">:]</span>
<span class="p">]</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">cross_kv_cache_block_offsets</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;cross_kv_cache_block_offsets&#39;</span><span class="p">,</span>
<span class="n">cross_shape</span><span class="p">)</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">host_cross_kv_cache_block_offsets</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;host_cross_kv_cache_block_offsets&#39;</span><span class="p">,</span>
<span class="n">cross_shape</span><span class="p">)</span>
<span class="n">cross_pool_pointers</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;host_cross_kv_cache_pool_pointers&#39;</span>
<span class="n">cross_pool_mapping</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;host_cross_kv_cache_pool_mapping&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">cross_pool_pointers</span><span class="p">],</span>
<span class="n">cross_pool_pointers</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">cross_pool_mapping</span><span class="p">],</span> <span class="n">cross_pool_mapping</span><span class="p">)</span>
<span class="k">if</span> <span class="n">prompt_embedding_table</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="s1">&#39;prompt_embedding_table&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">gen_tasks</span> <span class="o">=</span> <span class="n">tasks</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">gen_tasks</span> <span class="o">=</span> <span class="n">tasks</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">gen_tasks</span><span class="p">,</span> <span class="s1">&#39;tasks&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="s1">&#39;prompt_vocab_size&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="k">for</span> <span class="n">attn_idx</span><span class="p">,</span> <span class="n">layer_idx</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn_to_general_idx</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">next_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span>
<span class="n">max_context_length</span> <span class="o">+</span> <span class="n">step</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
<span class="c1"># We will make current layer&#39;s output KV-cache overwrite previous layers input KV-cache</span>
<span class="c1"># buffer id: ... 5, 6, 7, 8, 9, ...</span>
<span class="c1"># layer n: out in</span>
<span class="c1"># layer n+1: out in</span>
<span class="c1"># layer n+2 out in</span>
<span class="c1"># And when finish a step, we will make every layer&#39;s in/out buffer index subtract 1 in</span>
<span class="c1"># a circular buffer way to make sure current outputs become next step&#39;s inputs.</span>
<span class="n">num_buffers</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">input_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">attn_idx</span> <span class="o">-</span> <span class="p">(</span><span class="n">step</span> <span class="o">%</span> <span class="n">num_buffers</span><span class="p">))</span> <span class="o">%</span> <span class="n">num_buffers</span>
<span class="n">output_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">input_idx</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="n">num_buffers</span>
<span class="n">input_name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_buffer_names</span><span class="p">[</span><span class="n">input_idx</span><span class="p">]</span>
<span class="n">output_name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_buffer_names</span><span class="p">[</span><span class="n">output_idx</span><span class="p">]</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">input_name</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span>
<span class="n">next_shape</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">output_name</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">key_value_cache</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">key_value_cache</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;past_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">key_value_cache</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">cross_cache_buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;cross_past_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_cache_buffer</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;cross_present_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">!=</span> <span class="s1">&#39;recurrent&#39;</span><span class="p">:</span>
<span class="k">continue</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;conv_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;conv_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;rnn_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;rnn_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># conv state</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_mamba_conv1d_plugin</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_kernel</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rnn_conv_dim_size</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">rnn_conv_dim_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">conv_kernel</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span><span class="p">:</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;1_present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;past_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="n">conv_state_shape</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;past_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">,</span> <span class="n">conv_state_shape</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;1_present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">],</span>
<span class="sa">f</span><span class="s1">&#39;present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="c1"># rnn state</span>
<span class="n">rnn_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_rnn_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">rnn_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;past_rnn_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">rnn_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;present_rnn_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
<span class="n">slot_mapping</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">slot_mapping</span><span class="p">,</span> <span class="s1">&#39;slot_mapping&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="c1"># generation requests</span>
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">&quot;device_request_types&quot;</span><span class="p">)</span>
<span class="n">device_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span>
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">device_request_types</span><span class="p">,</span> <span class="s1">&#39;device_request_types&#39;</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="n">host_past_key_value_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># previous [past_kv_length, is_context] has been deprecated. only past_kv_length should be given here</span>
<span class="c1"># Note we should use max_context_length here to align to max -- but isn&#39;t this done in attn plugin&#39;s max_element() already?</span>
<span class="n">host_past_key_value_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
<span class="p">[</span><span class="n">max_context_length</span> <span class="o">+</span> <span class="n">step</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_past_key_value_lengths</span><span class="p">,</span>
<span class="s1">&#39;host_past_key_value_lengths&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_request_types</span><span class="p">,</span> <span class="s1">&#39;host_request_types&#39;</span><span class="p">)</span>
<span class="c1"># Sequence lengths are not used in the context phase actually.</span>
<span class="n">sequence_length</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">,</span> <span class="s1">&#39;sequence_length&#39;</span><span class="p">,</span>
<span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="p">))</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_sink_token_length</span><span class="p">,</span>
<span class="s1">&#39;host_sink_token_length&#39;</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">))</span>
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;host_max_attention_window_sizes&#39;</span><span class="p">,</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">,</span> <span class="p">))</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths_local</span><span class="p">,</span> <span class="s1">&#39;host_context_lengths&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_request_types</span><span class="p">,</span> <span class="s1">&#39;host_request_types&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths_local</span><span class="p">,</span>
<span class="s1">&#39;host_context_lengths&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">attention_mask</span><span class="p">,</span> <span class="s1">&#39;attention_mask&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span><span class="p">,</span> <span class="s1">&#39;all_reduce_workspace&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gemm_allreduce_plugin</span><span class="p">:</span>
<span class="n">found_tensor_names</span> <span class="o">=</span> <span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">found_tensor_names</span><span class="p">:</span>
<span class="k">if</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">&quot;gemm_allreduce_uc_out&quot;</span><span class="p">):</span>
<span class="n">add_tensor_from_pointer</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span><span class="o">.</span><span class="n">uc_ptr</span><span class="p">,</span>
<span class="n">name</span><span class="p">,</span>
<span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_size</span><span class="p">),</span>
<span class="n">str_dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span>
<span class="k">if</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">&quot;gemm_allreduce_mc_out&quot;</span><span class="p">):</span>
<span class="n">add_tensor_from_pointer</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span><span class="o">.</span><span class="n">mc_ptr</span><span class="p">,</span>
<span class="n">name</span><span class="p">,</span>
<span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_size</span><span class="p">),</span>
<span class="n">str_dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span>
<span class="k">if</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">&quot;gemm_allreduce_ipc_out&quot;</span><span class="p">):</span>
<span class="n">add_tensor_from_pointer</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span><span class="o">.</span><span class="n">get_ipc_ptrs</span><span class="p">(),</span>
<span class="n">name</span><span class="p">,</span>
<span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_size</span><span class="p">),</span>
<span class="n">str_dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span>
<span class="c1"># Since we are using a ping-pong context design and the lora weight remains constant within the same request,</span>
<span class="c1"># it is only necessary to set the lora weight for the first two steps.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_lora_plugin</span> <span class="ow">and</span> <span class="n">step</span> <span class="o">&lt;</span> <span class="mi">2</span><span class="p">:</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span>
<span class="n">layer_idx</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span>
<span class="k">for</span> <span class="n">lora_module</span> <span class="ow">in</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="o">+</span>
<span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span><span class="p">):</span>
<span class="n">lora_ranks</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_ranks_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">lora_ranks</span><span class="p">],</span> <span class="n">lora_ranks</span><span class="p">)</span>
<span class="n">lora_module</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">lora_module</span><span class="p">],</span> <span class="n">lora_module</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_input_lengths</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;cpu&#39;</span><span class="p">),</span>
<span class="s1">&#39;host_encoder_input_lengths&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="c1"># Spec Decoding mask and position offsets are fixed for the whole session for Medusa.</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;spec_decoding_packed_mask&#39;</span><span class="p">],</span>
<span class="s1">&#39;spec_decoding_packed_mask&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;spec_decoding_position_offsets&#39;</span><span class="p">],</span>
<span class="s1">&#39;spec_decoding_position_offsets&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;spec_decoding_generation_lengths&#39;</span><span class="p">],</span>
<span class="s1">&#39;spec_decoding_generation_lengths&#39;</span><span class="p">)</span>
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;spec_decoding_use&#39;</span><span class="p">],</span> <span class="s1">&#39;spec_decoding_use&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="n">set_redrafter_gen_tensors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">add_tensor</span><span class="p">,</span>
<span class="n">add_tensor_with_shape</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
<span class="k">return</span> <span class="n">tensors</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_prepare_context_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">if</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span>
<span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="c1"># For Medusa, last_token_ids should contain the actual indices</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span> <span class="c1"># sub 1 from context_lengths for indices</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="k">if</span> <span class="p">(</span><span class="n">use_gpt_attention_plugin</span>
<span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">)</span> <span class="ow">and</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">}</span>
<span class="k">if</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;max_context_length&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">([</span>
<span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
<span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">expand</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">perf_knob_tensor_size</span> <span class="o">=</span> <span class="mi">16</span>
<span class="n">context_runtime_perf_knobs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span>
<span class="n">perf_knob_tensor_size</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_block_mode</span><span class="p">:</span>
<span class="n">context_runtime_perf_knobs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># multi_block_mode</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable_context_fmha_fp32_acc</span><span class="p">:</span>
<span class="n">context_runtime_perf_knobs</span><span class="p">[</span>
<span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># enable_context_fmha_fp32_acc</span>
<span class="n">ret</span><span class="p">[</span><span class="s1">&#39;host_runtime_perf_knobs&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_runtime_perf_knobs</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;input_ids&#39;</span><span class="p">)</span>
<span class="n">pad_id</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;pad_id&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">_prepare_attention_mask</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">pad_id</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">attention_mask</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
<span class="n">position_ids</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="n">attention_mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">ret</span><span class="p">[</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">attention_mask</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="n">ret</span><span class="p">[</span><span class="s1">&#39;position_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">position_ids</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;position_ids_base&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="c1"># NOTE: Generate random tensors using torch</span>
<span class="n">redrafter_prepare_random_tensors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">initialize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">return</span> <span class="n">ret</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_prepare_generation_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">&quot;_prepare_generation_inputs&quot;</span><span class="p">)</span>
<span class="n">step</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;step&#39;</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_gpt_attention_plugin</span> <span class="ow">and</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span>
<span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">):</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="c1"># For Medusa, last_token_ids should be [bs * seq] and should contain the actual indices (starts from 1)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">context_lengths</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">&quot;last_token_ids_1s&quot;</span><span class="p">)</span>
<span class="c1"># update last_token_ids here (buffers already swapped)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_total_gen_token</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">context_lengths</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># For Medusa, last_token_ids should be [bs, seq] and should contain the actual indices (starts from 0)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">context_lengths</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">expand</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="k">if</span> <span class="p">(</span><span class="n">use_gpt_attention_plugin</span>
<span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">)</span> <span class="ow">and</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">&quot;last_token_ids_cumsum&quot;</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">}</span>
<span class="k">if</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">&quot;position_ids_update&quot;</span><span class="p">)</span>
<span class="c1"># set position_ids</span>
<span class="c1"># buffers are swapped but sequence_length is not updated at this point</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;position_ids_base&#39;</span><span class="p">]</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="s1">&#39;num_accepted_tokens&#39;</span><span class="p">]</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;packed_position_ids&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="o">-</span><span class="mi">1</span><span class="p">)[:</span><span class="bp">self</span><span class="o">.</span><span class="n">host_total_gen_token</span><span class="p">]</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">-=</span> <span class="mi">1</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">context_lengths</span> <span class="o">+</span> <span class="n">step</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">perf_knob_tensor_size</span> <span class="o">=</span> <span class="mi">16</span>
<span class="n">gen_runtime_perf_knobs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">perf_knob_tensor_size</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_block_mode</span><span class="p">:</span>
<span class="n">gen_runtime_perf_knobs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># multi_block_mode</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable_context_fmha_fp32_acc</span><span class="p">:</span>
<span class="n">gen_runtime_perf_knobs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># enable_context_fmha_fp32_acc</span>
<span class="n">ret</span><span class="p">[</span><span class="s1">&#39;host_runtime_perf_knobs&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">gen_runtime_perf_knobs</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">)</span>
<span class="n">num_beams</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;num_beams&#39;</span><span class="p">)</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">attention_mask</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">num_beams</span><span class="p">,</span> <span class="mi">1</span><span class="p">))),</span>
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">attention_mask</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
<span class="n">position_ids</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="n">attention_mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
<span class="n">ret</span><span class="p">[</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">attention_mask</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="n">ret</span><span class="p">[</span><span class="s1">&#39;position_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">position_ids</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="c1"># buffers are already swapped</span>
<span class="c1"># convert spec_decoding_mask to spec_decoding_packed_mask</span>
<span class="n">redrafter_convert_spec_decoding_mask_to_packed_mask</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;spec_decoding_generation_lengths&#39;</span><span class="p">])</span>
<span class="c1"># NOTE: Generate random tensors using torch</span>
<span class="n">redrafter_prepare_random_tensors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
<span class="k">return</span> <span class="n">ret</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_prepare_cross_attention_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">):</span>
<span class="n">cross_attention_mask_for_context</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">cross_attention_mask_for_gen</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">max_decoder_input_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="k">for</span> <span class="n">batch_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">decoder_input_length</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="n">local_mask_for_context</span> <span class="o">=</span> <span class="n">cross_attention_mask</span><span class="p">[</span>
<span class="n">batch_idx</span><span class="p">][:</span><span class="n">decoder_input_length</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">local_mask_for_gen</span> <span class="o">=</span> <span class="n">cross_attention_mask</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">][</span>
<span class="n">decoder_input_length</span><span class="p">:,</span> <span class="p">:]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">local_mask_for_context</span> <span class="o">=</span> <span class="n">local_mask_for_context</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">local_mask_for_context</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span>
<span class="n">local_mask_for_context</span><span class="p">,</span>
<span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span>
<span class="p">(</span><span class="n">max_decoder_input_length</span> <span class="o">-</span> <span class="n">decoder_input_length</span><span class="p">)),</span>
<span class="s2">&quot;constant&quot;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<span class="n">local_mask_for_gen</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span>
<span class="n">local_mask_for_gen</span><span class="p">,</span>
<span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span>
<span class="p">(</span><span class="n">max_decoder_input_length</span> <span class="o">-</span> <span class="n">decoder_input_length</span><span class="p">)),</span>
<span class="s2">&quot;constant&quot;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<span class="n">cross_attention_mask_for_context</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">local_mask_for_context</span><span class="p">)</span>
<span class="c1"># add additional dimension for batch size.</span>
<span class="n">cross_attention_mask_for_gen</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">local_mask_for_gen</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span><span class="n">cross_attention_mask_for_context</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span>
<span class="n">cross_attention_mask_for_gen</span><span class="p">)</span>
<div class="viewcode-block" id="GenerationSession.pp_communicate_new_tokens">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.pp_communicate_new_tokens">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">pp_communicate_new_tokens</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">should_stop</span><span class="p">,</span> <span class="n">cache_indir</span><span class="p">,</span>
<span class="n">sequence_length</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="k">for</span> <span class="n">pg</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">:</span>
<span class="k">if</span> <span class="n">pg</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">should_stop</span><span class="p">,</span> <span class="n">pg</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">cache_indir</span><span class="p">,</span> <span class="n">pg</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">,</span> <span class="n">pg</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">should_stop</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">cache_indir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="k">return</span> <span class="n">should_stop</span></div>
<div class="viewcode-block" id="GenerationSession.pp_communicate_final_output_ids">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.pp_communicate_final_output_ids">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">pp_communicate_final_output_ids</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">final_output_ids</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="k">return</span> <span class="n">final_output_ids</span></div>
<div class="viewcode-block" id="GenerationSession.finalize_decoder">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.finalize_decoder">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">finalize_decoder</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">,</span>
<span class="n">in_progress</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="c1"># output shape of self.gather_tree: [batch_size, beam_width, output_len]</span>
<span class="n">beam_hyps_args</span> <span class="o">=</span> <span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_cba</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_seq_len_cba</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs_cba</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores_cba</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs_cba</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span> <span class="ow">and</span> <span class="n">in_progress</span><span class="p">:</span>
<span class="c1"># self.gather_tree modifies these args.</span>
<span class="c1"># In streaming mode, this results in incorrect decoding in the following steps.</span>
<span class="n">beam_hyps_args</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">beam_hyps_args</span><span class="p">)</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_tree</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_probs_tiled</span><span class="p">,</span> <span class="o">*</span><span class="n">beam_hyps_args</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">finished</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">length_penalty</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span><span class="p">)</span>
<span class="c1"># Communicate ranks in Pipeline Parallelism</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pp_communicate_final_output_ids</span><span class="p">(</span>
<span class="n">final_output_ids</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="k">return</span> <span class="n">final_output_ids</span></div>
<div class="viewcode-block" id="GenerationSession.find_best_medusa_path">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.find_best_medusa_path">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">find_best_medusa_path</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">next_logits</span><span class="p">,</span>
<span class="n">temp</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">best_path</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span>
<span class="n">best_path_len</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span>
<span class="n">next_tokens</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span>
<span class="n">zero_pad</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">input_ids</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">zero_pad</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="n">temp</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">new_tokens_raw</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span>
<span class="n">next_logits</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span>
<span class="p">)</span> <span class="c1"># TODO: can be done by treating [bs, nT, vocab] as [bs*nT, vocab] and using decoderOp?</span>
<span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">new_tokens_raw</span><span class="p">,</span> <span class="n">zero_pad</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">input_paths</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">input_ids</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span><span class="p">]</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">new_paths</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">new_tokens</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span><span class="p">]</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">equality</span> <span class="o">=</span> <span class="n">input_paths</span><span class="p">[</span><span class="n">b</span><span class="p">][:,</span> <span class="mi">1</span><span class="p">:]</span> <span class="o">==</span> <span class="n">new_paths</span><span class="p">[</span><span class="n">b</span><span class="p">][:,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="n">paths_correct_len</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumprod</span><span class="p">(</span><span class="n">equality</span><span class="o">.</span><span class="n">int</span><span class="p">(),</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">best_path_len</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="n">paths_correct_len</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span>
<span class="k">if</span> <span class="n">best_path_len</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">best_path</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">paths_correct_len</span><span class="p">)</span>
<span class="n">next_tokens</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_paths</span><span class="p">[</span><span class="n">b</span><span class="p">][</span>
<span class="n">best_path</span><span class="p">[</span><span class="n">b</span><span class="p">]][:</span><span class="n">best_path_len</span><span class="p">[</span><span class="n">b</span><span class="p">]]</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">return</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_len</span><span class="p">,</span> <span class="n">next_tokens</span></div>
<div class="viewcode-block" id="GenerationSession.filter_medusa_logits">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.filter_medusa_logits">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">filter_medusa_logits</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span><span class="p">,</span>
<span class="n">medusa_logits</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> medusa_logits is of shape [nMH, bs, nMT+1, vocab]</span>
<span class="sd"> Returns [nMH, bs, vocab]</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">filtered_logits</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">medusa_logits</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">medusa_logits</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">medusa_logits</span> <span class="o">=</span> <span class="n">medusa_logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">idx</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span><span class="p">[</span><span class="n">best_path</span><span class="p">[</span><span class="n">b</span><span class="p">],</span> <span class="n">best_path_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span>
<span class="n">filtered_logits</span><span class="p">[:,</span> <span class="n">b</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">=</span> <span class="n">medusa_logits</span><span class="p">[:,</span> <span class="n">b</span><span class="p">,</span> <span class="n">idx</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span>
<span class="k">return</span> <span class="n">filtered_logits</span></div>
<div class="viewcode-block" id="GenerationSession.get_next_medusa_tokens">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.get_next_medusa_tokens">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">get_next_medusa_tokens</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">next_medusa_logits</span><span class="p">):</span>
<span class="n">next_medusa_tokens</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="n">next_medusa_logits</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="p">]</span> <span class="c1"># dummy token for now, TODO: update tree_ids and remove this</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">):</span>
<span class="n">medusa_token</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="n">next_medusa_logits</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:],</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_topks</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">indices</span>
<span class="n">next_medusa_tokens</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">medusa_token</span><span class="p">)</span>
<span class="n">next_medusa_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">next_medusa_tokens</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">next_medusa_tokens</span></div>
<div class="viewcode-block" id="GenerationSession.locate_accepted_draft_tokens">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.locate_accepted_draft_tokens">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">locate_accepted_draft_tokens</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_len</span><span class="p">,</span>
<span class="n">draft_paths</span><span class="p">):</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">&quot;locate_accepted_draft_tokens&quot;</span><span class="p">)</span>
<span class="n">best_path_len_tensor</span> <span class="o">=</span> <span class="n">best_path_len</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
<span class="n">best_path_len</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
<span class="n">best_path_len</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">accepted_draft_token_counts</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span>
<span class="n">best_path_len_tensor</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">best_path_len_tensor</span><span class="o">.</span><span class="n">device</span><span class="p">))</span>
<span class="n">accepted_draft_token_offsets</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">accepted_draft_token_offsets</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span>
<span class="n">accepted_draft_token_counts</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">accepted_draft_token_offsets_cpu</span> <span class="o">=</span> <span class="n">accepted_draft_token_offsets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
<span class="s1">&#39;cpu&#39;</span><span class="p">)</span>
<span class="n">packed_accepted_draft_tokens_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
<span class="n">accepted_draft_token_offsets_cpu</span><span class="p">[</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">seq_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">cur_draft_paths</span> <span class="o">=</span> <span class="n">draft_paths</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="k">else</span> <span class="n">draft_paths</span><span class="p">[</span>
<span class="n">seq_idx</span><span class="p">]</span>
<span class="n">seq_start</span> <span class="o">=</span> <span class="n">accepted_draft_token_offsets_cpu</span><span class="p">[</span><span class="n">seq_idx</span><span class="p">]</span>
<span class="n">seq_end</span> <span class="o">=</span> <span class="n">accepted_draft_token_offsets_cpu</span><span class="p">[</span><span class="n">seq_idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span>
<span class="n">seq_accepted_draft_count</span> <span class="o">=</span> <span class="n">seq_end</span> <span class="o">-</span> <span class="n">seq_start</span>
<span class="n">best_path_idx</span> <span class="o">=</span> <span class="n">best_path</span><span class="p">[</span><span class="n">seq_idx</span><span class="p">]</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
<span class="n">best_path</span><span class="p">[</span><span class="n">seq_idx</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="n">best_path</span><span class="p">[</span><span class="n">seq_idx</span><span class="p">]</span>
<span class="n">seq_accepted_token_indices</span> <span class="o">=</span> <span class="n">cur_draft_paths</span><span class="p">[</span>
<span class="n">best_path_idx</span><span class="p">,</span> <span class="mi">1</span><span class="p">:</span><span class="mi">1</span> <span class="o">+</span> <span class="n">seq_accepted_draft_count</span><span class="p">]</span>
<span class="n">packed_accepted_draft_tokens_indices</span><span class="p">[</span>
<span class="n">seq_start</span><span class="p">:</span><span class="n">seq_end</span><span class="p">]</span> <span class="o">=</span> <span class="n">seq_accepted_token_indices</span> <span class="o">-</span> <span class="mi">1</span>
<span class="c1"># print(&quot;KV offsets &amp; indices&quot;, accepted_draft_token_offsets,</span>
<span class="c1"># packed_accepted_draft_tokens_indices,)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
<span class="k">return</span> <span class="n">accepted_draft_token_offsets</span><span class="p">,</span> <span class="n">packed_accepted_draft_tokens_indices</span></div>
<div class="viewcode-block" id="GenerationSession.update_output_ids_by_offset">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.update_output_ids_by_offset">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">update_output_ids_by_offset</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_generated_ids</span><span class="p">,</span> <span class="n">offsets</span><span class="p">):</span>
<span class="c1"># output_ids [batch_size, padded_input_length]</span>
<span class="c1"># new_generated_ids [batch_size, padded_accepted_length]</span>
<span class="c1"># offsets [batch_size]</span>
<span class="c1"># FIXME: using fused kernel to update the padded output ids.</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="n">offsets</span><span class="p">[</span><span class="n">b</span><span class="p">]:(</span>
<span class="n">offsets</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
<span class="p">)]</span> <span class="o">=</span> <span class="n">new_generated_ids</span><span class="p">[</span><span class="n">b</span><span class="p">][:</span><span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]]</span>
<span class="k">return</span></div>
<div class="viewcode-block" id="GenerationSession.next_medusa_input_ids">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.next_medusa_input_ids">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">next_medusa_input_ids</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="c1"># self.new_tokens [batch_size, padded_accepted_length]</span>
<span class="c1"># self.accept_lengths [batch_size]</span>
<span class="c1"># self.medusa_new_tokens [batch_size, num_draft_tokens]</span>
<span class="c1"># FIXME: using fused kernel to generate the new medusa input ids.</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">[</span>
<span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="mi">1</span><span class="p">:]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="p">:]</span></div>
<div class="viewcode-block" id="GenerationSession.reorder_kv_cache_for_beam_search">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.reorder_kv_cache_for_beam_search">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">reorder_kv_cache_for_beam_search</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="c1"># Do nothing.</span>
<span class="k">return</span>
<span class="c1"># WAR: This degrades the latency performance in beam search</span>
<span class="c1"># due to memcpy. Recommend to use gpt attention plugin instead.</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span> <span class="o">==</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="n">cache_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span>
<span class="n">max_context_length</span> <span class="o">+</span> <span class="n">step</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">functools</span>
<span class="n">numel</span> <span class="o">=</span> <span class="n">functools</span><span class="o">.</span><span class="n">reduce</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="o">*</span> <span class="n">y</span><span class="p">,</span> <span class="n">cache_shape</span><span class="p">)</span>
<span class="c1"># attention layer num + 1 extra buffer.</span>
<span class="n">num_buffers</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span> <span class="o">+</span> <span class="mi">1</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn_to_general_idx</span><span class="p">:</span>
<span class="c1"># Cyclic buffers, an output becomes the next step&#39;s input.</span>
<span class="n">input_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">i</span> <span class="o">-</span> <span class="p">(</span><span class="n">step</span> <span class="o">%</span> <span class="n">num_buffers</span><span class="p">))</span> <span class="o">%</span> <span class="n">num_buffers</span>
<span class="n">presents</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_buffer_names</span><span class="p">[</span><span class="n">input_idx</span><span class="p">]]</span>
<span class="n">presents</span> <span class="o">=</span> <span class="n">presents</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)[:</span><span class="n">numel</span><span class="p">]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">cache_shape</span><span class="p">)</span>
<span class="c1"># parent_ids = (batch, beam, max_seq_len)</span>
<span class="n">parent_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span><span class="p">[</span><span class="o">...</span><span class="p">,</span>
<span class="n">max_context_length</span> <span class="o">+</span> <span class="n">step</span><span class="p">]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">for</span> <span class="n">batch_beam</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">):</span>
<span class="n">batch</span> <span class="o">=</span> <span class="n">batch_beam</span> <span class="o">//</span> <span class="n">beam_width</span>
<span class="k">if</span> <span class="n">parent_ids</span><span class="p">[</span><span class="n">batch_beam</span><span class="p">]</span> <span class="o">!=</span> <span class="n">batch_beam</span> <span class="o">%</span> <span class="n">beam_width</span><span class="p">:</span>
<span class="c1"># Update past kv cache to parent beam&#39;s cache.</span>
<span class="n">src_bbid</span> <span class="o">=</span> <span class="n">batch</span> <span class="o">*</span> <span class="n">beam_width</span> <span class="o">+</span> <span class="n">parent_ids</span><span class="p">[</span><span class="n">batch_beam</span><span class="p">]</span>
<span class="n">presents</span><span class="p">[</span><span class="n">batch_beam</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">=</span> <span class="n">presents</span><span class="p">[</span><span class="n">src_bbid</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span></div>
<span class="c1"># OPTIMIZE: need to optimize this early-stop workflow.</span>
<div class="viewcode-block" id="GenerationSession.early_stop_criteria">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.early_stop_criteria">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">early_stop_criteria</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">should_stop</span><span class="p">):</span>
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_stop</span><span class="p">[</span><span class="n">b</span><span class="p">]:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">continue</span>
<span class="c1"># output sequence length criteria.</span>
<span class="n">prev_total_output_length</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
<span class="c1"># end id criteria.</span>
<span class="n">end_id_mask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">[</span>
<span class="n">b</span><span class="p">,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
<span class="n">should_stop_with_end_id</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">end_id_mask</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_stop</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_stop</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="ow">or</span> <span class="p">(</span>
<span class="n">prev_total_output_length</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
<span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">)</span> <span class="ow">or</span> <span class="n">should_stop_with_end_id</span>
<span class="c1"># update accept lengths for the current step.</span>
<span class="k">if</span> <span class="p">(</span><span class="n">prev_total_output_length</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
<span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="n">prev_total_output_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">])</span>
<span class="k">if</span> <span class="n">should_stop_with_end_id</span><span class="p">:</span>
<span class="c1"># get the position of first end_id.</span>
<span class="n">end_id_pos</span> <span class="o">=</span> <span class="p">(</span><span class="n">end_id_mask</span><span class="p">)</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">as_tuple</span><span class="o">=</span><span class="kc">True</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">end_id_pos</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">total_accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
<span class="n">should_stop</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">should_stop</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="ow">or</span> <span class="p">(</span><span class="n">step</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span>
<span class="mi">1</span><span class="p">)</span> <span class="ow">or</span> <span class="n">torch</span><span class="o">.</span><span class="n">all</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_stop</span><span class="p">)</span>
<span class="k">return</span> <span class="n">should_stop</span></div>
<div class="viewcode-block" id="GenerationSession.medusa_decode_and_verify">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.medusa_decode_and_verify">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">medusa_decode_and_verify</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">logits</span><span class="p">):</span>
<span class="n">medusa_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;medusa_logits&#39;</span><span class="p">]</span>
<span class="n">best_path</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">best_path_lengths</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="c1"># logits buffer is of shape [bs, medusa_tokens+1, vocab]</span>
<span class="c1"># but during context phase, we get only [bs, 1, vocab] but contiguous</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)[:</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">logits</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">next_main_token_logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
<span class="n">next_main_token</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">next_main_token_logits</span><span class="p">,</span>
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span>
<span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">next_main_token</span>
<span class="c1"># NOTE: only one token&#39;s medusa logit will be written in.</span>
<span class="n">medusa_logits</span> <span class="o">=</span> <span class="n">medusa_logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span>
<span class="n">next_medusa_logits</span> <span class="o">=</span> <span class="n">medusa_logits</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
<span class="n">next_medusa_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_next_medusa_tokens</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">next_medusa_logits</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span> <span class="o">=</span> <span class="n">next_medusa_tokens</span><span class="p">[:,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_tree_ids</span><span class="p">[</span>
<span class="o">-</span><span class="bp">self</span><span class="o">.</span>
<span class="n">num_draft_tokens</span><span class="p">:]]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">next_token_logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
<span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span><span class="p">,</span> <span class="n">next_main_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">find_best_medusa_path</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span>
<span class="n">next_token_logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">best_path_lengths</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">to_padded_tensor</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">nested_tensor</span><span class="p">(</span><span class="n">next_main_tokens</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="c1">#FIXME end id padding.</span>
<span class="n">next_medusa_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter_medusa_logits</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span><span class="p">,</span> <span class="n">medusa_logits</span><span class="p">)</span>
<span class="n">next_medusa_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_next_medusa_tokens</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">next_medusa_logits</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span> <span class="o">=</span> <span class="n">next_medusa_tokens</span><span class="p">[:,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_tree_ids</span><span class="p">[</span>
<span class="o">-</span><span class="bp">self</span><span class="o">.</span>
<span class="n">num_draft_tokens</span><span class="p">:]]</span>
<span class="k">return</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span></div>
<div class="viewcode-block" id="GenerationSession.process_logits_including_draft">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.process_logits_including_draft">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">process_logits_including_draft</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span>
<span class="n">next_step_buffer</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> 1. Process logits to tokens and validate (Medusa) or process outputs (ReDrafter)</span>
<span class="sd"> 2. Extract early stop criteria here : self.accept_length</span>
<span class="sd"> 3. Update output ids : needs self.new_tokens and past_sequence_length</span>
<span class="sd"> 4. Get next input_ids : self.[new_tokens, accept_lengths, medusa_output_tokens]</span>
<span class="sd"> 5. Update KV cache : self.[sequence_length, num_draft_tokens]</span>
<span class="sd"> 6. Update sequence_length_buffer and past_kv_length</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="kc">False</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="c1"># NOTE: this function call also updates self.[accept_lengths, new_tokens, medusa_output_tokens]</span>
<span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_decode_and_verify</span><span class="p">(</span>
<span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">logits</span><span class="p">)</span>
<span class="n">last_draft_paths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span>
<span class="c1"># print(best_path, self.new_tokens, self.medusa_output_tokens)</span>
<span class="n">last_draft_tokens_len</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="k">if</span> <span class="n">step</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">else</span> <span class="mi">0</span>
<span class="n">cur_draft_tokens_len</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="c1"># buffers are swapped at this point</span>
<span class="n">last_draft_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;next_draft_tokens&#39;</span><span class="p">]</span>
<span class="n">new_draft_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;draft_tokens&#39;</span><span class="p">]</span>
<span class="n">last_draft_paths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s2">&quot;next_draft_indices&quot;</span><span class="p">]</span>
<span class="n">last_draft_tokens_len</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="s1">&#39;next_spec_decoding_generation_lengths&#39;</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">step</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">else</span> <span class="mi">0</span>
<span class="n">cur_draft_tokens_len</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="s1">&#39;spec_decoding_generation_lengths&#39;</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span>
<span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span> <span class="o">=</span> <span class="n">process_redrafter_outputs</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">last_draft_tokens</span><span class="p">,</span> <span class="n">new_draft_tokens</span><span class="p">)</span>
<span class="c1"># NOTE: stop criteria</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">&quot;early_stop_check&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">total_accept_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_stop</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">)</span>
<span class="n">should_stop</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">equal</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">step</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">early_stop_criteria</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span>
<span class="n">should_stop</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
<span class="c1"># NOTE: self.accept_lengths are the lengths of accepted tokens in the current step</span>
<span class="c1"># NOTE: self.sequence_length_buffer = num_past_kv_cache (accepted) + accept_lengths</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">&quot;update_output_ids&quot;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_output_ids_by_offset</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">-</span> <span class="n">last_draft_tokens_len</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="mi">1</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">next_medusa_input_ids</span><span class="p">()</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">best_path</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">best_path_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">accepted_draft_token_offsets</span><span class="p">,</span> <span class="n">packed_accepted_draft_tokens_indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">locate_accepted_draft_tokens</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span><span class="p">,</span> <span class="n">last_draft_paths</span><span class="p">)</span>
<span class="c1"># update the KV cache</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">&quot;kv_update&quot;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_updater</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
<span class="n">accepted_draft_token_offsets</span><span class="p">,</span>
<span class="n">packed_accepted_draft_tokens_indices</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span> <span class="n">last_draft_tokens_len</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">+</span> <span class="n">cur_draft_tokens_len</span> <span class="o">-</span> <span class="n">last_draft_tokens_len</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">+=</span> <span class="n">cur_draft_tokens_len</span> <span class="o">+</span> <span class="mi">1</span>
<span class="c1"># NOTE: set the accepted tokens for the last step.</span>
<span class="k">if</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
<span class="c1"># remove num_draft_tokens for next generation.</span>
<span class="c1"># Runtime: denotes kv cache length start positions.</span>
<span class="c1"># Output: denotes the length of sequence length (input ids + output ids)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">-</span> <span class="n">last_draft_tokens_len</span>
<span class="k">if</span> <span class="n">next_step_buffer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">next_step_buffer</span><span class="p">[</span><span class="s1">&#39;host_past_key_value_lengths&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">to_torch</span><span class="p">()</span><span class="o">.</span><span class="n">copy_</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">)</span>
<span class="k">return</span> <span class="n">should_stop</span></div>
<div class="viewcode-block" id="GenerationSession.handle_per_step">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.handle_per_step">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">handle_per_step</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">cache_indirections</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span>
<span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cross_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_cross_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cross_attention_mask_for_context</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">cross_attention_mask_for_gen</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">ite</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">next_step_tensors</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">RuntimeTensor</span><span class="p">],</span>
<span class="n">stop_words_data</span><span class="p">,</span>
<span class="n">bad_words_data</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">stopping_criteria</span><span class="p">:</span> <span class="n">StoppingCriteria</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="p">:</span> <span class="n">LogitsProcessor</span><span class="p">,</span>
<span class="n">output_generation_logits</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">,</span>
<span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;=================================== STEP </span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2"> ==================================&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span><span class="p">:</span>
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_0</span>
<span class="n">this_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">this_tgt_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">next_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_1</span>
<span class="n">this_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">this_tgt_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">next_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">position_ids_raw</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;position_ids&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">skip_cross_attn_blocks</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;skip_cross_attn_blocks&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">language_adapter_routings</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;language_adapter_routings&#39;</span><span class="p">,</span>
<span class="kc">None</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">model_inputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_context_inputs</span><span class="p">(</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="o">=</span><span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_context_length</span><span class="p">,</span>
<span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">,</span>
<span class="n">pad_id</span><span class="o">=</span><span class="n">scfg</span><span class="o">.</span><span class="n">pad_id</span><span class="p">,</span>
<span class="n">eos_id</span><span class="o">=</span><span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">)</span>
<span class="k">if</span> <span class="n">position_ids_raw</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># default iota position ids</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;position_ids&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># user input position ids</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">position_ids_raw</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">padded_position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">pad_sequence</span><span class="p">(</span>
<span class="n">position_ids_raw</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">padding_value</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">padded_position_ids</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">)</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">context_runtime_perf_knobs</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
<span class="s1">&#39;host_runtime_perf_knobs&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">host_context_progress</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="n">host_kv_cache_block_offsets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span><span class="o">.</span><span class="n">get_block_offsets</span><span class="p">(</span>
<span class="n">beam_width</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">kv_cache_block_offsets</span> <span class="o">=</span> <span class="n">host_kv_cache_block_offsets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">host_cross_kv_cache_block_offsets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_pools_kv_cache_manager</span><span class="o">.</span><span class="n">get_block_offsets</span><span class="p">(</span>
<span class="n">beam_width</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">cross_kv_cache_block_offsets</span> <span class="o">=</span> <span class="n">host_cross_kv_cache_block_offsets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
<span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">ctx_tensors</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_context_shape_buffer</span><span class="p">(</span>
<span class="n">input_ids</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">position_ids</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="p">,</span>
<span class="n">cross_attention_mask_for_context</span><span class="p">,</span>
<span class="n">this_src_cache_indirection</span><span class="p">,</span>
<span class="n">kv_cache_block_offsets</span><span class="p">,</span>
<span class="n">host_kv_cache_block_offsets</span><span class="p">,</span>
<span class="n">cross_kv_cache_block_offsets</span><span class="p">,</span>
<span class="n">host_cross_kv_cache_block_offsets</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">,</span>
<span class="n">host_runtime_perf_knobs</span><span class="o">=</span><span class="n">context_runtime_perf_knobs</span><span class="p">,</span>
<span class="n">host_context_progress</span><span class="o">=</span><span class="n">host_context_progress</span><span class="p">,</span>
<span class="n">skip_cross_attn_blocks</span><span class="o">=</span><span class="n">skip_cross_attn_blocks</span><span class="p">,</span>
<span class="n">language_adapter_routings</span><span class="o">=</span><span class="n">language_adapter_routings</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">ctx_context</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_set_tensors</span><span class="p">(</span><span class="n">context</span><span class="p">,</span> <span class="n">ctx_tensors</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span> <span class="o">=</span> <span class="p">{</span>
<span class="n">name</span><span class="p">:</span> <span class="n">tensor</span><span class="o">.</span><span class="n">to_torch</span><span class="p">()</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">tensor</span> <span class="ow">in</span> <span class="n">ctx_tensors</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
<span class="p">}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span><span class="p">:</span>
<span class="c1"># context mode, clean cuda graph instances</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">)]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span> <span class="ow">and</span> <span class="kc">False</span><span class="p">:</span> <span class="c1"># TODO: after TRT bug is fixed</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_check_tensors</span><span class="p">(</span><span class="n">context</span><span class="p">)</span>
<span class="c1"># dynamic_decoder currently use torch&#39;s current stream, so must let TRT enqueue use same stream here</span>
<span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">cuda_stream</span>
<span class="n">instance_idx</span> <span class="o">=</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span>
<span class="n">instance_idx</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># launch cuda graph</span>
<span class="n">CUASSERT</span><span class="p">(</span>
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphLaunch</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">],</span> <span class="n">stream</span><span class="p">))</span>
<span class="n">ok</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">ok</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_run</span><span class="p">(</span><span class="n">context</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">ok</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Executing TRT engine failed step=</span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2">!&quot;</span><span class="p">)</span>
<span class="c1"># TODO: remove this Windows WAR after https://nvbugs/4460474 is fixed.</span>
<span class="k">if</span> <span class="n">platform</span><span class="o">.</span><span class="n">system</span><span class="p">()</span> <span class="o">==</span> <span class="s2">&quot;Windows&quot;</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">context_logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span>
<span class="n">context_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="c1"># gather last token of context</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="c1"># reshape self.buffer[&#39;logits&#39;] from [bs, max_context_length, vocab]</span>
<span class="c1"># to [1, bs * max_context_length, vocab]</span>
<span class="c1"># Note that the data are put in the buffer without padding although</span>
<span class="c1"># the allocated buffer has padding.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">index_select</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">],</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">],</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">index</span><span class="o">=</span><span class="n">last_token_ids</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">))</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">beam_width</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span>
<span class="c1"># these tiled tensors are returned by handle_per_step(), so they can relay to the next generation calls</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">attention_mask</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="n">context_lengths</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="n">host_context_lengths</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">)</span>
<span class="k">if</span> <span class="n">encoder_input_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">encoder_input_lengths</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span>
<span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="k">if</span> <span class="n">tasks</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">tasks</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">tasks</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="c1"># Move tiling before logit computing of context</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">:</span>
<span class="c1"># Note: this tiles both self attn cache and cross attn</span>
<span class="c1"># cache! both names contain &quot;present_key_value&quot;</span>
<span class="k">if</span> <span class="s2">&quot;present_key_value&quot;</span> <span class="ow">in</span> <span class="n">key</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">key</span><span class="p">],</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># In the OOTB path, KV cache should be contiguously</span>
<span class="c1"># tiled since TRT engine allocates past_kv cache of</span>
<span class="c1"># length context_length, i.e., we need a buffer of</span>
<span class="c1"># shape (batch * beam, 2, heads, context_length, head_size).</span>
<span class="n">b</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">d</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span>
<span class="n">numel</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">b</span> <span class="o">*</span> <span class="n">h</span> <span class="o">*</span> <span class="p">(</span><span class="n">max_context_length</span> <span class="o">+</span> <span class="n">step</span><span class="p">)</span> <span class="o">*</span> <span class="n">d</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">_contiguous_tile_beam_width</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">key</span><span class="p">],</span> <span class="n">numel</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">],</span> <span class="n">beam_width</span><span class="p">)</span>
<span class="n">generation_logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span> <span class="ow">or</span> <span class="n">output_generation_logits</span><span class="p">:</span>
<span class="n">generation_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="c1"># Initialize sequence_lengths (no paddings) for the generation phase.</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span> <span class="c1"># Medusa/ReDrafter has its own logic</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="c1"># to simplify some processing logic, always swap buffers after execution</span>
<span class="n">exchange_redrafter_buffers</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
<span class="c1"># NOTE: handle next step.</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">step</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># Set shape and address for the next step</span>
<span class="n">model_inputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_generation_inputs</span><span class="p">(</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">,</span>
<span class="n">step</span><span class="o">=</span><span class="n">step</span><span class="p">,</span>
<span class="n">num_beams</span><span class="o">=</span><span class="n">beam_width</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_mask</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">position_ids_raw</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;position_ids&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">:]</span> <span class="o">+</span> <span class="n">step</span> <span class="o">+</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">position_ids_raw</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;last_token_ids&#39;</span><span class="p">)</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">gen_runtime_perf_knobs</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;host_runtime_perf_knobs&#39;</span><span class="p">,</span>
<span class="kc">None</span><span class="p">)</span>
<span class="n">host_context_progress</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="c1"># Prepare for the next step, and always allocate 1 token slot.</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="c1"># Iterate to the next step in KV cache manager.</span>
<span class="c1"># Increase number of tokens for all unfinished sequences.</span>
<span class="c1"># And allocate new blocks if needed.</span>
<span class="c1"># We set this to False for all sequences, since we use only length criterion to stop now</span>
<span class="c1"># OPTIMIZE: find a better of adding multiple tokens for paged kv cache.</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">&quot;paged_kv_alloc&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_draft_tokens</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">add_token_count</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_draft_tokens</span> <span class="o">+</span>
<span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="k">assert</span> <span class="n">add_token_count</span> <span class="o">&gt;</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">add_token_count</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">False</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="c1"># Allocate kv cache token slots for next step.</span>
<span class="c1"># Make sure there are always &gt; (num_draft_tokens + 1) free token slots.</span>
<span class="c1"># Allocate (num_draft_tokens + 1) * 2 for safety as we don&#39;t know the current step or next step&#39;s accepted lengths.</span>
<span class="n">add_token_count</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span>
<span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="k">assert</span> <span class="n">add_token_count</span> <span class="o">&gt;</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">add_token_count</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">False</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">False</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">&quot;paged_kv_post_alloc&quot;</span><span class="p">)</span>
<span class="n">host_kv_cache_block_offsets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span><span class="o">.</span><span class="n">get_block_offsets</span><span class="p">(</span>
<span class="n">beam_width</span><span class="p">)</span>
<span class="n">kv_cache_block_offsets</span> <span class="o">=</span> <span class="n">host_kv_cache_block_offsets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">host_cross_kv_cache_block_offsets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_pools_kv_cache_manager</span><span class="o">.</span><span class="n">get_block_offsets</span><span class="p">(</span>
<span class="n">beam_width</span><span class="p">)</span>
<span class="n">cross_kv_cache_block_offsets</span> <span class="o">=</span> <span class="n">host_cross_kv_cache_block_offsets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
<span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
<span class="n">next_context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_1</span> <span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_0</span>
<span class="n">cross_attention_mask_step</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">cross_attention_mask_for_gen</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># cross_attention_mask_for_gen shape [batch_size, max_output_length, max_encoder_input_length]</span>
<span class="n">decode_step</span> <span class="o">=</span> <span class="n">step</span>
<span class="k">if</span> <span class="n">decode_step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">decode_step</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">cross_attention_mask_step</span> <span class="o">=</span> <span class="n">cross_attention_mask_for_gen</span><span class="p">[:,</span> <span class="p">(</span>
<span class="n">decode_step</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span> <span class="p">:]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">cross_attention_mask_step</span> <span class="o">=</span> <span class="n">cross_attention_mask_for_gen</span><span class="p">[:,</span> <span class="p">(</span>
<span class="n">decode_step</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span><span class="n">decode_step</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">next_step_tensors</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_next_step_shape_buffer</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">,</span>
<span class="n">step</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">position_ids</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="p">,</span>
<span class="n">cross_attention_mask_step</span><span class="p">,</span>
<span class="n">next_src_cache_indirection</span><span class="p">,</span>
<span class="n">kv_cache_block_offsets</span><span class="p">,</span>
<span class="n">host_kv_cache_block_offsets</span><span class="p">,</span>
<span class="n">cross_kv_cache_block_offsets</span><span class="p">,</span>
<span class="n">host_cross_kv_cache_block_offsets</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">,</span>
<span class="n">host_runtime_perf_knobs</span><span class="o">=</span><span class="n">gen_runtime_perf_knobs</span><span class="p">,</span>
<span class="n">host_context_progress</span><span class="o">=</span><span class="n">host_context_progress</span><span class="p">,</span>
<span class="n">skip_cross_attn_blocks</span><span class="o">=</span><span class="n">skip_cross_attn_blocks</span><span class="p">,</span>
<span class="n">language_adapter_routings</span><span class="o">=</span><span class="n">language_adapter_routings</span><span class="p">)</span>
<span class="c1"># there are some tensors created inside the _get_next_step_shape_buffer, not owned by any object</span>
<span class="c1"># needs to pro-long the life time of the tensors inside the next_step_tensors array</span>
<span class="c1"># otherwise, it maybe released before the next step actually enqueued</span>
<span class="c1"># one way to prolong it is to return the list, and destroy it in next step by assigning new values</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">&quot;_set_tensors&quot;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_set_tensors</span><span class="p">(</span><span class="n">next_context</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
<span class="k">if</span> <span class="n">logger</span><span class="o">.</span><span class="n">level</span> <span class="o">==</span> <span class="s2">&quot;verbose&quot;</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">print_context_info</span><span class="p">(</span>
<span class="n">next_context</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="n">next_context</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_1</span><span class="p">))</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_capture_cuda_graph_and_instantiate</span><span class="p">(</span>
<span class="n">next_context</span><span class="p">,</span> <span class="n">stream</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">process_logits_including_draft</span><span class="p">(</span>
<span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">logits</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">process_logits_including_draft</span><span class="p">(</span>
<span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="n">logits_processor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span>
<span class="n">context_lengths</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">,</span>
<span class="n">in_progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># keep the shape as same as huggingface stopping_criteria</span>
<span class="n">final_output_ids_</span> <span class="o">=</span> <span class="n">final_output_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">final_output_ids</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">logits_processor</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">final_output_ids_</span><span class="p">,</span>
<span class="n">logits</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">&#39;logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">logits</span>
<span class="c1"># [batch_size x beam_width, vocab_size_padded] -&gt; [batch_size, beam_width, vocab_size_padded]</span>
<span class="n">next_token_logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
<span class="n">decode_step</span> <span class="o">=</span> <span class="n">step</span> <span class="o">+</span> <span class="n">max_context_length</span>
<span class="n">stop_words_list_ptrs</span><span class="p">,</span> <span class="n">stop_words_lens</span><span class="p">,</span> <span class="n">max_stop_words_len</span> <span class="o">=</span> <span class="n">stop_words_data</span>
<span class="n">bad_words_list_ptrs</span><span class="p">,</span> <span class="n">bad_words_lens</span><span class="p">,</span> <span class="n">max_bad_words_len</span> <span class="o">=</span> <span class="n">bad_words_data</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_decoder</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span>
<span class="n">next_token_logits</span><span class="p">,</span> <span class="n">decode_step</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span><span class="p">,</span>
<span class="n">ite</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">embedding_bias_opt</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">sequence_limit_lengths</span><span class="p">,</span>
<span class="n">stop_words_list_ptrs</span><span class="p">,</span> <span class="n">stop_words_lens</span><span class="p">,</span>
<span class="n">max_stop_words_len</span><span class="p">,</span> <span class="n">bad_words_list_ptrs</span><span class="p">,</span> <span class="n">bad_words_lens</span><span class="p">,</span>
<span class="n">max_bad_words_len</span><span class="p">,</span> <span class="n">this_src_cache_indirection</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">finished</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">finished</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs_tiled</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span><span class="p">,</span>
<span class="n">this_tgt_cache_indirection</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_cba</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_seq_len_cba</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs_cba</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores_cba</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs_cba</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span><span class="p">,</span>
<span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">reorder_kv_cache_for_beam_search</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span>
<span class="k">if</span> <span class="n">stopping_criteria</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span>
<span class="n">context_lengths</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">,</span>
<span class="n">in_progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># keep the shape as same as huggingface stopping_criteria</span>
<span class="n">final_output_ids_</span> <span class="o">=</span> <span class="n">final_output_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">final_output_ids</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
<span class="n">should_stop</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">stopping_criteria</span><span class="p">(</span>
<span class="n">step</span><span class="p">,</span> <span class="n">final_output_ids_</span><span class="p">,</span> <span class="n">logits</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_is_profiling</span><span class="p">():</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">context</span><span class="o">.</span><span class="n">report_to_profiler</span><span class="p">():</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;Runtime report to profiler failed.&quot;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_insert_step_to_profiler</span><span class="p">(</span><span class="n">step</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pp_communicate_new_tokens</span><span class="p">(</span>
<span class="n">should_stop</span><span class="p">,</span> <span class="n">this_tgt_cache_indirection</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="k">if</span> <span class="p">(</span><span class="n">step</span> <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">should_stop</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="ow">and</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">()):</span>
<span class="c1"># Free all blocks in all sequences.</span>
<span class="c1"># With in-flight batching and while loop we&#39;ll free some sequences, when they are done</span>
<span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">True</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cross_pools_kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">True</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dump_debug_buffers</span><span class="p">(</span><span class="n">step</span><span class="p">)</span>
<span class="k">if</span> <span class="n">next_step_tensors</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span> <span class="o">=</span> <span class="p">{</span>
<span class="n">name</span><span class="p">:</span> <span class="n">tensor</span><span class="o">.</span><span class="n">to_torch</span><span class="p">()</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">tensor</span> <span class="ow">in</span> <span class="n">next_step_tensors</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
<span class="p">}</span>
<span class="k">return</span> <span class="n">should_stop</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">context_logits</span><span class="p">,</span> <span class="n">generation_logits</span><span class="p">,</span> <span class="n">encoder_input_lengths</span></div>
<div class="viewcode-block" id="GenerationSession.dump_debug_buffers">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.dump_debug_buffers">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">dump_debug_buffers</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># restricted written tensors according to filter</span>
<span class="n">debug_tensor_names</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span><span class="o">.</span><span class="n">keys</span><span class="p">()))</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">debug_tensor_names</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">all</span><span class="p">([</span><span class="n">kk</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">k</span> <span class="k">for</span> <span class="n">kk</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span><span class="p">]):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="n">k</span><span class="p">)</span>
<span class="n">debug_dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;tllm_debug/PP_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_rank</span><span class="si">}</span><span class="s2">/TP_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_rank</span><span class="si">}</span><span class="s2">/CP_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">cp_rank</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="n">debug_dir</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="c1"># convert tensor name to valid file name</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Saving: &quot;</span><span class="p">,</span> <span class="n">name</span><span class="p">)</span>
<span class="n">fname</span> <span class="o">=</span> <span class="n">name</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;/&quot;</span><span class="p">,</span> <span class="s2">&quot;.&quot;</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">torch_to_numpy</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">float</span><span class="p">())</span>
<span class="n">np</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">debug_dir</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">fname</span><span class="si">}</span><span class="s2">-step</span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2">.npy&quot;</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="n">txt_format</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="si">%d</span><span class="s2">&quot;</span> <span class="k">if</span> <span class="n">t</span><span class="o">.</span><span class="n">dtype</span> <span class="ow">in</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">]</span> <span class="k">else</span> <span class="s1">&#39;</span><span class="si">%.18e</span><span class="s1">&#39;</span>
<span class="n">np</span><span class="o">.</span><span class="n">savetxt</span><span class="p">(</span>
<span class="n">debug_dir</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">fname</span><span class="si">}</span><span class="s2">-step</span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2">.txt&quot;</span><span class="p">,</span>
<span class="n">t</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span> <span class="c1"># savetxt accepts 2 dims only</span>
<span class="n">fmt</span><span class="o">=</span><span class="n">txt_format</span><span class="p">)</span></div>
<div class="viewcode-block" id="GenerationSession.decode_regular">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode_regular">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">decode_regular</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">cache_indirections</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">ite</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">stop_words_data</span><span class="p">,</span>
<span class="n">bad_words_data</span><span class="p">,</span>
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">output_generation_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stopping_criteria</span><span class="p">:</span> <span class="n">StoppingCriteria</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="p">:</span> <span class="n">LogitsProcessor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">kv_cache_block_offsets</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">host_kv_cache_block_offsets</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">cross_kv_cache_block_offsets</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">host_cross_kv_cache_block_offsets</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">outputs_context_logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">outputs_generation_logits</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">def</span><span class="w"> </span><span class="nf">get_outputs_dict</span><span class="p">(</span><span class="n">output_ids</span><span class="p">,</span> <span class="n">num_steps</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">):</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;output_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">output_ids</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_log_probs</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;log_probs&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span>
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_cum_log_probs</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;cum_log_probs&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span>
<span class="k">if</span> <span class="n">output_sequence_lengths</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span>
<span class="s1">&#39;sequence_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">])</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;context_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_context_logits</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span> <span class="ow">or</span> <span class="n">output_generation_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;generation_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_generation_logits</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;steps_to_finish&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">num_steps</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;medusa_output_tokens&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;accept_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_temperature</span> <span class="o">!=</span> <span class="mf">0.0</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;medusa_output_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_logits</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="n">benchmark_profiler</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;benchmark_profiler&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">generation_phase_step_count</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="n">benchmark_profiler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">benchmark_profiler</span><span class="o">.</span><span class="n">is_recording_perf_profile</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_set_profiler</span><span class="p">()</span>
<span class="k">def</span><span class="w"> </span><span class="nf">profile_fn</span><span class="p">(</span><span class="n">benchmark_profiler_obj</span><span class="p">,</span> <span class="n">step_count</span><span class="p">):</span>
<span class="k">if</span> <span class="n">benchmark_profiler_obj</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">benchmark_profiler_obj</span><span class="o">.</span><span class="n">record_cuda_event</span><span class="p">(</span><span class="s1">&#39;last_token&#39;</span><span class="p">)</span>
<span class="n">benchmark_profiler_obj</span><span class="o">.</span><span class="n">record_elapsed_time</span><span class="p">(</span>
<span class="s1">&#39;first_token&#39;</span><span class="p">,</span> <span class="s1">&#39;last_token&#39;</span><span class="p">,</span> <span class="s1">&#39;generation_time&#39;</span><span class="p">)</span>
<span class="n">benchmark_profiler_obj</span><span class="o">.</span><span class="n">add_aux_info</span><span class="p">(</span><span class="s1">&#39;generation_step_count&#39;</span><span class="p">,</span>
<span class="n">step_count</span><span class="p">)</span>
<span class="c1"># prepare cross attention mask.</span>
<span class="n">cross_attention_mask_for_context</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">cross_attention_mask_for_gen</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">cross_attention_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">cross_attention_mask_for_context</span><span class="p">,</span> <span class="n">cross_attention_mask_for_gen</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_cross_attention_mask</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="c1"># When we use plugin, the data type of cross_attention_mask is bool.</span>
<span class="c1"># When we don&#39;t use plugin, the data type of cross_attention_mask is int32</span>
<span class="n">cross_attention_mask_for_context</span> <span class="o">=</span> <span class="n">cross_attention_mask_for_context</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
<span class="n">cross_attention_mask_for_gen</span> <span class="o">=</span> <span class="n">cross_attention_mask_for_gen</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
<span class="n">next_step_tensors</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">):</span>
<span class="n">should_stop</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">context_logits</span><span class="p">,</span> <span class="n">generation_logits</span><span class="p">,</span> <span class="n">encoder_input_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">handle_per_step</span><span class="p">(</span>
<span class="n">cache_indirections</span><span class="o">=</span><span class="n">cache_indirections</span><span class="p">,</span>
<span class="n">step</span><span class="o">=</span><span class="n">step</span><span class="p">,</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_context_length</span><span class="p">,</span>
<span class="n">beam_width</span><span class="o">=</span><span class="n">beam_width</span><span class="p">,</span>
<span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="o">=</span><span class="n">hidden_states</span><span class="p">,</span>
<span class="n">scfg</span><span class="o">=</span><span class="n">scfg</span><span class="p">,</span>
<span class="n">kv_cache_block_offsets</span><span class="o">=</span><span class="n">kv_cache_block_offsets</span><span class="p">,</span>
<span class="n">host_kv_cache_block_offsets</span><span class="o">=</span><span class="n">host_kv_cache_block_offsets</span><span class="p">,</span>
<span class="n">cross_kv_cache_block_offsets</span><span class="o">=</span><span class="n">cross_kv_cache_block_offsets</span><span class="p">,</span>
<span class="n">host_cross_kv_cache_block_offsets</span><span class="o">=</span>
<span class="n">host_cross_kv_cache_block_offsets</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="o">=</span><span class="n">prompt_embedding_table</span><span class="p">,</span>
<span class="n">tasks</span><span class="o">=</span><span class="n">tasks</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="o">=</span><span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_mask</span><span class="p">,</span>
<span class="n">cross_attention_mask_for_context</span><span class="o">=</span>
<span class="n">cross_attention_mask_for_context</span><span class="p">,</span>
<span class="n">cross_attention_mask_for_gen</span><span class="o">=</span><span class="n">cross_attention_mask_for_gen</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="o">=</span><span class="n">prompt_vocab_size</span><span class="p">,</span>
<span class="n">ite</span><span class="o">=</span><span class="n">ite</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="o">=</span><span class="n">sequence_limit_lengths</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="o">=</span><span class="n">sequence_lengths</span><span class="p">,</span>
<span class="n">next_step_tensors</span><span class="o">=</span><span class="n">next_step_tensors</span><span class="p">,</span>
<span class="n">stop_words_data</span><span class="o">=</span><span class="n">stop_words_data</span><span class="p">,</span>
<span class="n">bad_words_data</span><span class="o">=</span><span class="n">bad_words_data</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="o">=</span><span class="n">encoder_output</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="o">=</span><span class="n">encoder_input_lengths</span><span class="p">,</span>
<span class="n">stopping_criteria</span><span class="o">=</span><span class="n">stopping_criteria</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="o">=</span><span class="n">logits_processor</span><span class="p">,</span>
<span class="n">output_generation_logits</span><span class="o">=</span><span class="n">output_generation_logits</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">if</span> <span class="n">benchmark_profiler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">benchmark_profiler</span><span class="o">.</span><span class="n">record_cuda_event</span><span class="p">(</span><span class="s1">&#39;first_token&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">generation_phase_step_count</span> <span class="o">=</span> <span class="n">generation_phase_step_count</span> <span class="o">+</span> <span class="mi">1</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">outputs_context_logits</span> <span class="o">=</span> <span class="n">context_logits</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span> <span class="ow">or</span> <span class="n">output_generation_logits</span><span class="p">:</span>
<span class="n">outputs_generation_logits</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">generation_logits</span><span class="p">)</span>
<span class="k">if</span> <span class="n">should_stop</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
<span class="n">profile_fn</span><span class="p">(</span><span class="n">benchmark_profiler</span><span class="p">,</span> <span class="n">generation_phase_step_count</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="c1"># just hack away for now</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="n">final_output_ids</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span>
<span class="n">max_seq_length</span> <span class="o">-</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_draft_tokens</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span>
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="n">scfg</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
<span class="k">return</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">,</span> <span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">final_output_ids</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;context_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_context_logits</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span> <span class="ow">or</span> <span class="n">output_generation_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;generation_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_generation_logits</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">,</span> <span class="s2">&quot;the custom decoder doesn&#39;t support medusa/redrafter.&quot;</span>
<span class="n">profile_fn</span><span class="p">(</span><span class="n">benchmark_profiler</span><span class="p">,</span> <span class="n">generation_phase_step_count</span><span class="p">)</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span> <span class="n">scfg</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
<span class="k">return</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">final_output_ids</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;context_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_context_logits</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span> <span class="ow">or</span> <span class="n">output_generation_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;generation_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_generation_logits</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">None</span></div>
<div class="viewcode-block" id="GenerationSession.decode_stream">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode_stream">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">decode_stream</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">cache_indirections</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">ite</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">stop_words_data</span><span class="p">,</span>
<span class="n">bad_words_data</span><span class="p">,</span>
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">output_generation_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stopping_criteria</span><span class="p">:</span> <span class="n">StoppingCriteria</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="p">:</span> <span class="n">LogitsProcessor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">kv_cache_block_offsets</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">host_kv_cache_block_offsets</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">cross_kv_cache_block_offsets</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">host_cross_kv_cache_block_offsets</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">outputs_context_logits</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">def</span><span class="w"> </span><span class="nf">get_outputs_dict</span><span class="p">(</span><span class="n">output_ids</span><span class="p">):</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;output_ids&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">output_ids</span>
<span class="k">if</span> <span class="n">output_sequence_lengths</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span>
<span class="s1">&#39;sequence_lengths&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">])</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
<span class="n">outputs</span><span class="p">[</span><span class="s1">&#39;context_logits&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_context_logits</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="c1"># prepare cross attention mask.</span>
<span class="n">cross_attention_mask_for_context</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">cross_attention_mask_for_gen</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">cross_attention_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">cross_attention_mask_for_context</span><span class="p">,</span> <span class="n">cross_attention_mask_for_gen</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_cross_attention_mask</span><span class="p">(</span>
<span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">)</span>
<span class="n">next_step_tensors</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">):</span>
<span class="n">should_stop</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">context_logits</span><span class="p">,</span> <span class="n">generation_logits</span><span class="p">,</span> <span class="n">encoder_input_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">handle_per_step</span><span class="p">(</span>
<span class="n">cache_indirections</span><span class="o">=</span><span class="n">cache_indirections</span><span class="p">,</span>
<span class="n">step</span><span class="o">=</span><span class="n">step</span><span class="p">,</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_context_length</span><span class="p">,</span>
<span class="n">beam_width</span><span class="o">=</span><span class="n">beam_width</span><span class="p">,</span>
<span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="o">=</span><span class="n">hidden_states</span><span class="p">,</span>
<span class="n">scfg</span><span class="o">=</span><span class="n">scfg</span><span class="p">,</span>
<span class="n">kv_cache_block_offsets</span><span class="o">=</span><span class="n">kv_cache_block_offsets</span><span class="p">,</span>
<span class="n">host_kv_cache_block_offsets</span><span class="o">=</span><span class="n">host_kv_cache_block_offsets</span><span class="p">,</span>
<span class="n">cross_kv_cache_block_offsets</span><span class="o">=</span><span class="n">cross_kv_cache_block_offsets</span><span class="p">,</span>
<span class="n">host_cross_kv_cache_block_offsets</span><span class="o">=</span>
<span class="n">host_cross_kv_cache_block_offsets</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="o">=</span><span class="n">prompt_embedding_table</span><span class="p">,</span>
<span class="n">tasks</span><span class="o">=</span><span class="n">tasks</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="o">=</span><span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_mask</span><span class="p">,</span>
<span class="n">cross_attention_mask_for_context</span><span class="o">=</span>
<span class="n">cross_attention_mask_for_context</span><span class="p">,</span>
<span class="n">cross_attention_mask_for_gen</span><span class="o">=</span><span class="n">cross_attention_mask_for_gen</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="o">=</span><span class="n">prompt_vocab_size</span><span class="p">,</span>
<span class="n">ite</span><span class="o">=</span><span class="n">ite</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="o">=</span><span class="n">sequence_limit_lengths</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="o">=</span><span class="n">sequence_lengths</span><span class="p">,</span>
<span class="n">next_step_tensors</span><span class="o">=</span><span class="n">next_step_tensors</span><span class="p">,</span>
<span class="n">stop_words_data</span><span class="o">=</span><span class="n">stop_words_data</span><span class="p">,</span>
<span class="n">bad_words_data</span><span class="o">=</span><span class="n">bad_words_data</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="o">=</span><span class="n">encoder_output</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="o">=</span><span class="n">encoder_input_lengths</span><span class="p">,</span>
<span class="n">stopping_criteria</span><span class="o">=</span><span class="n">stopping_criteria</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="o">=</span><span class="n">logits_processor</span><span class="p">,</span>
<span class="n">output_generation_logits</span><span class="o">=</span><span class="n">output_generation_logits</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">outputs_context_logits</span> <span class="o">=</span> <span class="n">context_logits</span>
<span class="k">if</span> <span class="n">should_stop</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="n">scfg</span><span class="p">,</span>
<span class="n">in_progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">final_output_ids</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">yield</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
<span class="k">return</span>
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span> <span class="n">scfg</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">yield</span> <span class="n">final_output_ids</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">yield</span> <span class="kc">None</span></div>
<div class="viewcode-block" id="GenerationSession.decode_batch">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode_batch">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">decode_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">streaming</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">input_ids</span><span class="p">,</span> <span class="n">context_lengths</span> <span class="o">=</span> <span class="n">_prepare_input_ids</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span>
<span class="n">sampling_config</span><span class="p">,</span>
<span class="n">streaming</span><span class="o">=</span><span class="n">streaming</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
<span class="c1"># As dynamic_decoder uses torch&#39;s current stream, we must ensure it runs on the same stream that</span>
<span class="c1"># dynamic_decoder was set up with</span>
<div class="viewcode-block" id="GenerationSession.decode">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode">[docs]</a>
<span class="nd">@cuda_stream_guard</span>
<span class="k">def</span><span class="w"> </span><span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stop_words_list</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">bad_words_list</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">streaming</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">output_generation_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stopping_criteria</span><span class="p">:</span> <span class="n">StoppingCriteria</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="p">:</span> <span class="n">LogitsProcessor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">scfg</span> <span class="o">=</span> <span class="n">sampling_config</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">beam_width</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span>
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="n">host_context_lengths</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
<span class="k">assert</span> <span class="n">batch_size</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> \
<span class="s2">&quot;Given batch size is different from the one used in setup(),&quot;</span> \
<span class="s2">&quot;rerun the setup function with the new batch size to avoid buffer overflow.&quot;</span>
<span class="k">assert</span> <span class="n">max_context_length</span> <span class="o">&lt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_context_length</span><span class="p">,</span> \
<span class="s2">&quot;Given input length is large then the one used in setup(),&quot;</span> \
<span class="s2">&quot;rerun the setup function with the new max_context_length to avoid buffer overflow.&quot;</span>
<span class="k">assert</span> <span class="n">beam_width</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_width</span><span class="p">,</span> \
<span class="s2">&quot;Given beam width is different from the one used in setup(),&quot;</span> \
<span class="s2">&quot;rerun the setup function with the new beam width to avoid buffer overflow.&quot;</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">&lt;=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> \
<span class="s2">&quot;Given sink token length is larger than shortest context length,&quot;</span> \
<span class="s2">&quot;rerun the setup function with a smaller sink token length.&quot;</span>
<span class="n">ite</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># index of local batches, will always be 0 if pp_size = 1</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span> <span class="ow">and</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;Packed 2D input must have shape [1, &lt;sum of input lengths&gt;]&quot;</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">__setup_decoder</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">scfg</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">&#39;Buffer not allocated, please call setup first!&#39;</span><span class="p">)</span>
<span class="n">sequence_limit_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="c1"># Sequence_lengths for the dynamic decoder still has the input paddings.</span>
<span class="n">sequence_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">max_context_length</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">cache_indirections</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
<span class="p">),</span>
<span class="mi">0</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
<span class="p">),</span>
<span class="mi">0</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="p">]</span> <span class="c1"># ping-pong buffers</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
<span class="n">max_num_tokens</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span>
<span class="n">batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span>
<span class="n">hidden_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">max_num_tokens</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">))</span>
<span class="c1"># Init KV cache block manager</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
<span class="n">num_blocks</span><span class="p">,</span> <span class="n">max_blocks_per_seq</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_num_paged_blocks</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;host_kv_cache_pool_pointers&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_memory_pool_allocator</span><span class="o">.</span><span class="n">get_kv_cache_pool_pointers</span><span class="p">(</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;host_kv_cache_pool_mapping&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_memory_pool_allocator</span><span class="o">.</span><span class="n">pool_mapping</span>
<span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span> <span class="o">=</span> <span class="n">PoolsKVCacheManager</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_memory_pool_allocator</span><span class="o">.</span><span class="n">pools_metadata</span><span class="p">,</span>
<span class="n">max_blocks_per_seq</span><span class="p">,</span>
<span class="n">num_blocks</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
<span class="n">max_attention_window_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
<span class="n">beam_width</span><span class="o">=</span><span class="n">beam_width</span><span class="p">,</span>
<span class="n">sink_token_len</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">cross_num_blocks</span><span class="p">,</span> <span class="n">max_cross_blocks_per_seq</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_num_paged_blocks</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span><span class="p">,</span> <span class="n">sink_token_length</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;host_cross_kv_cache_pool_pointers&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cross_memory_pool_allocator</span><span class="o">.</span><span class="n">get_kv_cache_pool_pointers</span><span class="p">(</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
<span class="sa">f</span><span class="s1">&#39;host_cross_kv_cache_pool_mapping&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cross_memory_pool_allocator</span><span class="o">.</span><span class="n">pool_mapping</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cross_pools_kv_cache_manager</span> <span class="o">=</span> <span class="n">PoolsKVCacheManager</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_cross_memory_pool_allocator</span><span class="o">.</span><span class="n">pools_metadata</span><span class="p">,</span>
<span class="n">max_cross_blocks_per_seq</span><span class="p">,</span>
<span class="n">cross_num_blocks</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
<span class="n">max_attention_window_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span><span class="p">,</span>
<span class="n">beam_width</span><span class="o">=</span><span class="n">beam_width</span><span class="p">,</span>
<span class="n">sink_token_len</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span><span class="p">)</span>
<span class="c1"># Add sequences to the manager</span>
<span class="k">for</span> <span class="n">bi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">generation_sequence</span> <span class="o">=</span> <span class="n">GenerationSequence</span><span class="p">(</span><span class="n">seq_idx</span><span class="o">=</span><span class="n">bi</span><span class="p">,</span>
<span class="n">batch_idx</span><span class="o">=</span><span class="n">bi</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span><span class="o">.</span><span class="n">add_sequence</span><span class="p">(</span>
<span class="n">generation_sequence</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
<span class="n">cross_generation_sequence</span> <span class="o">=</span> <span class="n">GenerationSequence</span><span class="p">(</span><span class="n">seq_idx</span><span class="o">=</span><span class="n">bi</span><span class="p">,</span>
<span class="n">batch_idx</span><span class="o">=</span><span class="n">bi</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cross_pools_kv_cache_manager</span><span class="o">.</span><span class="n">add_sequence</span><span class="p">(</span>
<span class="n">cross_generation_sequence</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span><span class="p">,</span>
<span class="n">always_share_across_beam</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="c1"># cross attention paged kv cache should always share the context blocks across beams</span>
<span class="c1"># due to the fact that we are not adding new key/value cache to cross kv in generation</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_kv_cache_quant</span><span class="p">():</span>
<span class="c1"># Since torch does not support fp8 now, using int8 here.</span>
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">int8</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">history_max_seq_length</span> <span class="o">=</span> <span class="p">[</span><span class="n">max_context_length</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_updater</span> <span class="o">=</span> <span class="n">KVCacheUpdater</span><span class="p">()</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_updater</span><span class="o">.</span><span class="n">init_paged_kv_cache</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
<span class="n">kv_cache_type</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;host_kv_cache_pool_pointers&#39;</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">past_key_value_list</span> <span class="o">=</span> <span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">&#39;present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
<span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_updater</span><span class="o">.</span><span class="n">init_linear_kv_cache</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
<span class="n">kv_cache_type</span><span class="p">,</span> <span class="n">past_key_value_list</span><span class="p">)</span>
<span class="n">stop_words_lens</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">stop_words_list_ptrs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">max_stop_words_len</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="n">stop_words_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">stop_words_list</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">stop_words_list</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
<span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">max_stop_words_len</span> <span class="o">=</span> <span class="n">stop_words_list</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
<span class="n">stop_words_lens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
<span class="n">max_stop_words_len</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">stop_words_list_ptrs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="k">for</span> <span class="n">bi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">stop_words_list_ptrs</span><span class="p">[</span><span class="n">bi</span><span class="p">]</span> <span class="o">=</span> <span class="n">stop_words_list</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(</span>
<span class="p">)</span> <span class="o">+</span> <span class="n">bi</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">max_stop_words_len</span> <span class="o">*</span> <span class="n">stop_words_list</span><span class="o">.</span><span class="n">element_size</span><span class="p">(</span>
<span class="p">)</span>
<span class="n">stop_words_list_ptrs</span> <span class="o">=</span> <span class="n">stop_words_list_ptrs</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">stop_words_data</span> <span class="o">=</span> <span class="p">(</span><span class="n">stop_words_list_ptrs</span><span class="p">,</span> <span class="n">stop_words_lens</span><span class="p">,</span>
<span class="n">max_stop_words_len</span><span class="p">)</span>
<span class="n">bad_words_lens</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">bad_words_list_ptrs</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">max_bad_words_len</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="n">bad_words_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">bad_words_list</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">bad_words_list</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
<span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">max_bad_words_len</span> <span class="o">=</span> <span class="n">bad_words_list</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
<span class="n">bad_words_lens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
<span class="n">max_bad_words_len</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">bad_words_list_ptrs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="k">for</span> <span class="n">bi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">bad_words_list_ptrs</span><span class="p">[</span><span class="n">bi</span><span class="p">]</span> <span class="o">=</span> <span class="n">bad_words_list</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(</span>
<span class="p">)</span> <span class="o">+</span> <span class="n">bi</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">max_bad_words_len</span> <span class="o">*</span> <span class="n">bad_words_list</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span>
<span class="n">bad_words_list_ptrs</span> <span class="o">=</span> <span class="n">bad_words_list_ptrs</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">bad_words_data</span> <span class="o">=</span> <span class="p">(</span><span class="n">bad_words_list_ptrs</span><span class="p">,</span> <span class="n">bad_words_lens</span><span class="p">,</span>
<span class="n">max_bad_words_len</span><span class="p">)</span>
<span class="c1"># start context phase</span>
<span class="k">if</span> <span class="n">streaming</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode_stream</span><span class="p">(</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">scfg</span><span class="o">=</span><span class="n">scfg</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="o">=</span><span class="n">sequence_lengths</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="o">=</span><span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_context_length</span><span class="p">,</span>
<span class="n">beam_width</span><span class="o">=</span><span class="n">beam_width</span><span class="p">,</span>
<span class="n">cache_indirections</span><span class="o">=</span><span class="n">cache_indirections</span><span class="p">,</span>
<span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="o">=</span><span class="n">hidden_states</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="o">=</span><span class="n">prompt_embedding_table</span><span class="p">,</span>
<span class="n">tasks</span><span class="o">=</span><span class="n">tasks</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="o">=</span><span class="n">prompt_vocab_size</span><span class="p">,</span>
<span class="n">ite</span><span class="o">=</span><span class="n">ite</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="o">=</span><span class="n">sequence_limit_lengths</span><span class="p">,</span>
<span class="n">output_generation_logits</span><span class="o">=</span><span class="n">output_generation_logits</span><span class="p">,</span>
<span class="n">stop_words_data</span><span class="o">=</span><span class="n">stop_words_data</span><span class="p">,</span>
<span class="n">bad_words_data</span><span class="o">=</span><span class="n">bad_words_data</span><span class="p">,</span>
<span class="n">output_sequence_lengths</span><span class="o">=</span><span class="n">output_sequence_lengths</span><span class="p">,</span>
<span class="n">return_dict</span><span class="o">=</span><span class="n">return_dict</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="o">=</span><span class="n">encoder_output</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="o">=</span><span class="n">encoder_input_lengths</span><span class="p">,</span>
<span class="n">stopping_criteria</span><span class="o">=</span><span class="n">stopping_criteria</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="o">=</span><span class="n">logits_processor</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="o">=</span><span class="n">cross_attention_mask</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode_regular</span><span class="p">(</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
<span class="n">scfg</span><span class="o">=</span><span class="n">scfg</span><span class="p">,</span>
<span class="n">sequence_lengths</span><span class="o">=</span><span class="n">sequence_lengths</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="o">=</span><span class="n">host_context_lengths</span><span class="p">,</span>
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_context_length</span><span class="p">,</span>
<span class="n">beam_width</span><span class="o">=</span><span class="n">beam_width</span><span class="p">,</span>
<span class="n">cache_indirections</span><span class="o">=</span><span class="n">cache_indirections</span><span class="p">,</span>
<span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">,</span>
<span class="n">hidden_states</span><span class="o">=</span><span class="n">hidden_states</span><span class="p">,</span>
<span class="n">prompt_embedding_table</span><span class="o">=</span><span class="n">prompt_embedding_table</span><span class="p">,</span>
<span class="n">tasks</span><span class="o">=</span><span class="n">tasks</span><span class="p">,</span>
<span class="n">prompt_vocab_size</span><span class="o">=</span><span class="n">prompt_vocab_size</span><span class="p">,</span>
<span class="n">ite</span><span class="o">=</span><span class="n">ite</span><span class="p">,</span>
<span class="n">sequence_limit_lengths</span><span class="o">=</span><span class="n">sequence_limit_lengths</span><span class="p">,</span>
<span class="n">stop_words_data</span><span class="o">=</span><span class="n">stop_words_data</span><span class="p">,</span>
<span class="n">bad_words_data</span><span class="o">=</span><span class="n">bad_words_data</span><span class="p">,</span>
<span class="n">output_sequence_lengths</span><span class="o">=</span><span class="n">output_sequence_lengths</span><span class="p">,</span>
<span class="n">output_generation_logits</span><span class="o">=</span><span class="n">output_generation_logits</span><span class="p">,</span>
<span class="n">return_dict</span><span class="o">=</span><span class="n">return_dict</span><span class="p">,</span>
<span class="n">encoder_output</span><span class="o">=</span><span class="n">encoder_output</span><span class="p">,</span>
<span class="n">encoder_input_lengths</span><span class="o">=</span><span class="n">encoder_input_lengths</span><span class="p">,</span>
<span class="n">stopping_criteria</span><span class="o">=</span><span class="n">stopping_criteria</span><span class="p">,</span>
<span class="n">logits_processor</span><span class="o">=</span><span class="n">logits_processor</span><span class="p">,</span>
<span class="n">cross_attention_mask</span><span class="o">=</span><span class="n">cross_attention_mask</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">,</span>
<span class="p">)</span></div>
</div>
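# --- Editor's note: illustrative sketch, not part of the original source. ---
# Putting the pieces together, a typical non-streaming call into the session
# looks roughly like the following (setup() arguments abbreviated; the
# engine buffer, model config and mapping are assumed to exist):
#
#     session = GenerationSession(model_config, engine_buffer, mapping)
#     session.setup(batch_size=2, max_context_length=128, max_new_tokens=64)
#     outputs = session.decode(input_ids, context_lengths, sampling_config,
#                              return_dict=True)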
<div class="viewcode-block" id="ChatGLMGenerationSession">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.ChatGLMGenerationSession">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">ChatGLMGenerationSession</span><span class="p">(</span><span class="n">GenerationSession</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">model_config</span><span class="p">:</span> <span class="n">ModelConfig</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">,</span>
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">stream</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
<span class="n">model_config</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">,</span>
<span class="n">mapping</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="p">,</span>
<span class="n">debug_tensors_to_save</span><span class="p">,</span>
<span class="n">cuda_graph_mode</span><span class="p">,</span>
<span class="n">stream</span><span class="p">,</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_prepare_context_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;max_context_length&#39;</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">input_lengths_acc</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
<span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">IntTensor</span><span class="p">([</span><span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">cuda</span><span class="p">(),</span> <span class="n">context_lengths</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="n">i</span><span class="p">]:</span><span class="n">input_lengths_acc</span><span class="p">[</span>
<span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span>
<span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="mi">2</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="c1"># specialization for GLM series models</span>
<span class="k">if</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;pad_id&quot;</span><span class="p">]</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">50256</span><span class="p">,</span> <span class="mi">50259</span><span class="p">]:</span>
<span class="k">if</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;pad_id&quot;</span><span class="p">]</span> <span class="o">==</span> <span class="mi">50256</span><span class="p">:</span> <span class="c1"># glm_2b / glm_10b</span>
<span class="n">mask_ids</span> <span class="o">=</span> <span class="p">[</span><span class="mi">50260</span><span class="p">,</span> <span class="mi">50264</span><span class="p">,</span> <span class="mi">50263</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span> <span class="c1"># glm_10b_chinese / glm_large_chinese</span>
<span class="n">mask_ids</span> <span class="o">=</span> <span class="p">[</span><span class="mi">50003</span><span class="p">,</span> <span class="mi">50008</span><span class="p">,</span> <span class="mi">50009</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span> <span class="o">=</span> \
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">length</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;input_ids&quot;</span><span class="p">][</span>
<span class="mi">0</span><span class="p">:</span><span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">kwargs</span><span class="p">[</span>
<span class="s2">&quot;input_ids&quot;</span><span class="p">][</span><span class="nb">sum</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="n">i</span><span class="p">]</span>
<span class="p">):</span><span class="nb">sum</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="n">i</span><span class="p">])</span> <span class="o">+</span>
<span class="n">length</span><span class="p">]</span>
<span class="n">mask_index</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">input_ids</span> <span class="o">==</span> <span class="nb">id</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">int</span><span class="p">()</span> <span class="k">for</span> <span class="nb">id</span> <span class="ow">in</span> <span class="n">mask_ids</span>
<span class="p">]</span>
<span class="n">tail_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">([</span><span class="n">max_context_length</span><span class="p">])</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">mask_index</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">tail_index</span><span class="p">)</span>
<span class="n">mask_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">mask_index</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">min</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">mask_index</span><span class="p">)</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="nb">sum</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">])</span> <span class="o">-</span>
<span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">mask_index</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">position_ids</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">)</span>
<span class="c1"># specialization for GLM series models</span>
<span class="k">if</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;pad_id&quot;</span><span class="p">]</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">50256</span><span class="p">,</span> <span class="mi">50259</span><span class="p">]:</span>
<span class="k">if</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;pad_id&quot;</span><span class="p">]</span> <span class="o">==</span> <span class="mi">50256</span><span class="p">:</span> <span class="c1"># glm_2b / glm_10b</span>
<span class="n">mask_ids</span> <span class="o">=</span> <span class="p">[</span><span class="mi">50260</span><span class="p">,</span> <span class="mi">50264</span><span class="p">,</span> <span class="mi">50263</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span> <span class="c1"># glm_10b_chinese / glm_large_chinese</span>
<span class="n">mask_ids</span> <span class="o">=</span> <span class="p">[</span><span class="mi">50003</span><span class="p">,</span> <span class="mi">50008</span><span class="p">,</span> <span class="mi">50009</span><span class="p">]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span> <span class="o">=</span> \
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">length</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;input_ids&quot;</span><span class="p">][</span><span class="n">i</span><span class="p">]</span>
<span class="n">mask_index</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">input_ids</span> <span class="o">==</span> <span class="nb">id</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">int</span><span class="p">()</span> <span class="k">for</span> <span class="nb">id</span> <span class="ow">in</span> <span class="n">mask_ids</span>
<span class="p">]</span>
<span class="n">tail_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">([</span><span class="n">max_context_length</span><span class="p">])</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">mask_index</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">tail_index</span><span class="p">)</span>
<span class="n">mask_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">mask_index</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">min</span><span class="p">()</span>
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">mask_index</span><span class="p">)</span>
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">mask_index</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">length</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">2</span>
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">perf_knob_tensor_size</span> <span class="o">=</span> <span class="mi">16</span>
<span class="n">context_runtime_perf_knobs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">perf_knob_tensor_size</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;position_ids&#39;</span><span class="p">:</span> <span class="n">position_ids</span><span class="p">,</span>
<span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">,</span>
<span class="s1">&#39;host_runtime_perf_knobs&#39;</span><span class="p">:</span> <span class="n">context_runtime_perf_knobs</span>
<span class="p">}</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">inputs</span><span class="p">[</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">attention_mask</span>
<span class="k">return</span> <span class="n">inputs</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_prepare_generation_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="n">step</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;step&#39;</span><span class="p">)</span>
<span class="n">num_beams</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;num_beams&#39;</span><span class="p">)</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_tile_beam_width_chatglm</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">new_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">num_beams</span>
<span class="n">tile_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">new_shape</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">tile_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="n">tile_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">tile_size</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">new_shape</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
<span class="k">return</span> <span class="n">new_tensor</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">num_beams</span><span class="p">]</span> <span class="o">-</span> <span class="mi">2</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">step</span> <span class="o">+</span> <span class="mi">2</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">_tile_beam_width_chatglm</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="c1"># specialization for GLM series models</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">data</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="c1"># specialization for GLM series models</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">data</span><span class="o">.</span><span class="n">append</span><span class="p">([[</span><span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span><span class="p">[</span><span class="n">i</span><span class="p">]],</span> <span class="p">[</span><span class="n">step</span> <span class="o">+</span> <span class="mi">2</span><span class="p">]])</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
<span class="n">data</span><span class="o">.</span><span class="n">append</span><span class="p">([[</span><span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">num_beams</span><span class="p">]</span> <span class="o">-</span> <span class="mi">2</span><span class="p">],</span>
<span class="p">[</span><span class="n">step</span> <span class="o">+</span> <span class="mi">2</span><span class="p">]])</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>
<span class="n">perf_knob_tensor_size</span> <span class="o">=</span> <span class="mi">16</span>
<span class="n">generation_runtime_perf_knobs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span>
<span class="n">perf_knob_tensor_size</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;position_ids&#39;</span><span class="p">:</span> <span class="n">position_ids</span><span class="p">,</span>
<span class="s1">&#39;last_token_ids&#39;</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">,</span>
<span class="s1">&#39;host_runtime_perf_knobs&#39;</span><span class="p">:</span> <span class="n">generation_runtime_perf_knobs</span>
<span class="p">}</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">inputs</span><span class="p">[</span><span class="s1">&#39;attention_mask&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">attention_mask</span>
<span class="k">return</span> <span class="n">inputs</span></div>
<div class="viewcode-block" id="QWenForCausalLMGenerationSession">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.QWenForCausalLMGenerationSession">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">QWenForCausalLMGenerationSession</span><span class="p">(</span><span class="n">GenerationSession</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">model_config</span><span class="p">:</span> <span class="n">ModelConfig</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">,</span>
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">stream</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">global_max_input_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span><span class="p">,</span>
<span class="n">global_max_output_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4096</span><span class="p">,</span>
<span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">model_config</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">,</span>
<span class="n">mapping</span><span class="p">,</span>
<span class="n">debug_mode</span><span class="p">,</span>
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="n">debug_tensors_to_save</span><span class="p">,</span>
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="n">cuda_graph_mode</span><span class="p">,</span>
<span class="n">stream</span><span class="o">=</span><span class="n">stream</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">global_max_input_length</span> <span class="o">=</span> <span class="n">global_max_input_length</span>
<span class="bp">self</span><span class="o">.</span><span class="n">global_max_output_length</span> <span class="o">=</span> <span class="n">global_max_output_length</span>
<div class="viewcode-block" id="QWenForCausalLMGenerationSession.generate">
<a class="viewcode-back" href="../../../legacy/python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.QWenForCausalLMGenerationSession.generate">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">generate</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
<span class="n">max_new_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">runtime_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">max_input_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">input_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="n">max_new_tokens</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">max_new_tokens</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">global_max_output_length</span> <span class="o">-</span> <span class="n">max_input_length</span><span class="p">)</span>
<span class="c1"># setup batch_size, max_input_length, max_output_len</span>
<span class="bp">self</span><span class="o">.</span><span class="n">setup</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="n">input_lengths</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_input_length</span><span class="p">,</span>
<span class="n">max_new_tokens</span><span class="o">=</span><span class="n">max_new_tokens</span><span class="p">)</span>
<span class="n">output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">input_lengths</span><span class="p">,</span> <span class="n">sampling_config</span><span class="p">)</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="k">if</span> <span class="n">runtime_rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">output_ids</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="p">:]</span>
<span class="k">return</span> <span class="n">outputs</span></div>
</div>
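
# A minimal usage sketch (illustrative only, not part of the module). It
# assumes model_config, engine_buffer, and mapping were built elsewhere, e.g.
# by deserializing an engine, and that end_id / pad_id come from the tokenizer.
def _example_qwen_generate(model_config, engine_buffer, mapping, input_ids,
                           input_lengths, end_id, pad_id):
    session = QWenForCausalLMGenerationSession(model_config, engine_buffer,
                                               mapping)
    sampling_config = SamplingConfig(end_id=end_id, pad_id=pad_id, num_beams=1)
    # Rank 0 receives beam 0 of the output ids; other ranks return None.
    return session.generate(input_ids, input_lengths, sampling_config,
                            max_new_tokens=128)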
</pre></div>
</article>
<footer class="prev-next-footer d-print-none">
<div class="prev-next-area">
</div>
</footer>
</div>
<div class="bd-sidebar-secondary"></div>
</div>
<footer class="bd-footer-content">
</footer>
</main>
</div>
</div>
<!-- Scripts loaded after <body> so the DOM is not blocked -->
<script defer src="../../../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
<script defer src="../../../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
<footer class="bd-footer">
<div class="bd-footer__inner bd-page-width">
<div class="footer-items__start">
<div class="footer-item">
<a class="footer-brand logo" href="https://www.nvidia.com">
<img src="../../../_static/nvidia-logo-horiz-rgb-1c-blk-for-screen.svg" class="logo__image only-light" alt="NVIDIA"/>
<img src="../../../_static/nvidia-logo-horiz-rgb-1c-wht-for-screen.svg" class="logo__image only-dark" alt="NVIDIA"/>
</a></div>
<div class="footer-item">
<div class="footer-links">
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/privacy-policy/">Privacy Policy</a>
|
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/privacy-center/">Your Privacy Choices</a>
|
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/terms-of-service/">Terms of Service</a>
|
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/accessibility/">Accessibility</a>
|
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/company-policies/">Corporate Policies</a>
|
<a class="external" href="https://www.nvidia.com/en-us/product-security/">Product Security</a>
|
<a class="external" href="https://www.nvidia.com/en-us/contact/">Contact</a>
</div>
</div>
<div class="footer-item">
<p class="copyright">
Copyright © 2025, NVidia.
<br/>
</p>
</div>
<div class="footer-item">
<div class="extra_footer">
<p>Last updated on November 23, 2025.</p>
<p>This page is generated by TensorRT-LLM commit <a href="https://github.com/NVIDIA/TensorRT-LLM/tree/a761585">a761585</a>.</p>
</div></div>
</div>
</div>
</footer>
</body>
</html>