mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
5470 lines
874 KiB
HTML
5470 lines
874 KiB
HTML
|
|
|
|
<!DOCTYPE html>
|
|
|
|
|
|
<html lang="en" data-content_root="../../../" >
|
|
|
|
<head>
|
|
<meta charset="utf-8" />
|
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
|
<title>tensorrt_llm.runtime.generation — TensorRT-LLM</title>
|
|
|
|
|
|
|
|
<script data-cfasync="false">
|
|
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
|
|
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
|
|
</script>
|
|
<!--
|
|
this give us a css class that will be invisible only if js is disabled
|
|
-->
|
|
<noscript>
|
|
<style>
|
|
.pst-js-only { display: none !important; }
|
|
|
|
</style>
|
|
</noscript>
|
|
|
|
<!-- Loaded before other Sphinx assets -->
|
|
<link href="../../../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
|
<link href="../../../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
|
|
|
<link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=8f2a1f02" />
|
|
<link rel="stylesheet" type="text/css" href="../../../_static/styles/nvidia-sphinx-theme.css?v=df3ac72c" />
|
|
<link rel="stylesheet" type="text/css" href="../../../_static/copybutton.css?v=76b2166b" />
|
|
<link rel="stylesheet" type="text/css" href="../../../_static/autodoc_pydantic.css" />
|
|
|
|
<!-- So that users can add custom icons -->
|
|
<script src="../../../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
|
|
<!-- Pre-loaded scripts that we'll load fully later -->
|
|
<link rel="preload" as="script" href="../../../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
|
|
<link rel="preload" as="script" href="../../../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
|
|
|
|
<script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
|
|
<script src="../../../_static/doctools.js?v=9a2dae69"></script>
|
|
<script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
|
|
<script src="../../../_static/clipboard.min.js?v=a7894cd8"></script>
|
|
<script src="../../../_static/copybutton.js?v=65e89d2a"></script>
|
|
<script>DOCUMENTATION_OPTIONS.pagename = '_modules/tensorrt_llm/runtime/generation';</script>
|
|
<script>
|
|
DOCUMENTATION_OPTIONS.theme_version = '0.16.1';
|
|
DOCUMENTATION_OPTIONS.theme_switcher_json_url = './_static/switcher.json';
|
|
DOCUMENTATION_OPTIONS.theme_switcher_version_match = '0.21.0rc0';
|
|
DOCUMENTATION_OPTIONS.show_version_warning_banner =
|
|
false;
|
|
</script>
|
|
<link rel="icon" href="../../../_static/favicon.png"/>
|
|
<link rel="index" title="Index" href="../../../genindex.html" />
|
|
<link rel="search" title="Search" href="../../../search.html" />
|
|
|
|
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
|
<meta name="docsearch:language" content="en"/>
|
|
<meta name="docsearch:version" content="0.21.0rc0" />
|
|
|
|
|
|
</head>
|
|
|
|
|
|
<body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode="">
|
|
|
|
|
|
|
|
<div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div>
|
|
|
|
<div id="pst-scroll-pixel-helper"></div>
|
|
|
|
<button type="button" class="btn rounded-pill" id="pst-back-to-top">
|
|
<i class="fa-solid fa-arrow-up"></i>Back to top</button>
|
|
|
|
|
|
<dialog id="pst-search-dialog">
|
|
|
|
<form class="bd-search d-flex align-items-center"
|
|
action="../../../search.html"
|
|
method="get">
|
|
<i class="fa-solid fa-magnifying-glass"></i>
|
|
<input type="search"
|
|
class="form-control"
|
|
name="q"
|
|
placeholder="Search the docs ..."
|
|
aria-label="Search the docs ..."
|
|
autocomplete="off"
|
|
autocorrect="off"
|
|
autocapitalize="off"
|
|
spellcheck="false"/>
|
|
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span>
|
|
</form>
|
|
</dialog>
|
|
|
|
<div class="pst-async-banner-revealer d-none">
|
|
<aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside>
|
|
</div>
|
|
|
|
|
|
<header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none">
|
|
<div class="bd-header__inner bd-page-width">
|
|
<button class="pst-navbar-icon sidebar-toggle primary-toggle" aria-label="Site navigation">
|
|
<span class="fa-solid fa-bars"></span>
|
|
</button>
|
|
|
|
|
|
<div class="col-lg-3 navbar-header-items__start">
|
|
|
|
<div class="navbar-item">
|
|
|
|
|
|
|
|
|
|
|
|
<a class="navbar-brand logo" href="../../../index.html">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<img src="../../../_static/nvidia-logo-horiz-rgb-blk-for-screen.svg" class="logo__image only-light" alt="TensorRT-LLM - Home"/>
|
|
<img src="../../../_static/nvidia-logo-horiz-rgb-wht-for-screen.svg" class="logo__image only-dark pst-js-only" alt="TensorRT-LLM - Home"/>
|
|
|
|
|
|
<p class="title logo__title">TensorRT-LLM</p>
|
|
|
|
</a></div>
|
|
|
|
</div>
|
|
|
|
<div class="col-lg-9 navbar-header-items">
|
|
|
|
<div class="me-auto navbar-header-items__center">
|
|
|
|
<div class="navbar-item">
|
|
|
|
|
|
<div class="version-switcher__container dropdown pst-js-only">
|
|
<button id="pst-version-switcher-button-2"
|
|
type="button"
|
|
class="version-switcher__button btn btn-sm dropdown-toggle"
|
|
data-bs-toggle="dropdown"
|
|
aria-haspopup="listbox"
|
|
aria-controls="pst-version-switcher-list-2"
|
|
aria-label="Version switcher list"
|
|
>
|
|
Choose version <!-- this text may get changed later by javascript -->
|
|
<span class="caret"></span>
|
|
</button>
|
|
<div id="pst-version-switcher-list-2"
|
|
class="version-switcher__menu dropdown-menu list-group-flush py-0"
|
|
role="listbox" aria-labelledby="pst-version-switcher-button-2">
|
|
<!-- dropdown will be populated by javascript on page load -->
|
|
</div>
|
|
</div></div>
|
|
|
|
</div>
|
|
|
|
|
|
<div class="navbar-header-items__end">
|
|
|
|
<div class="navbar-item navbar-persistent--container">
|
|
|
|
|
|
<button class="btn search-button-field search-button__button pst-js-only" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
|
<i class="fa-solid fa-magnifying-glass"></i>
|
|
<span class="search-button__default-text">Search</span>
|
|
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
|
|
</button>
|
|
</div>
|
|
|
|
|
|
<div class="navbar-item">
|
|
|
|
<button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button pst-js-only" aria-label="Color mode" data-bs-title="Color mode" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
|
<i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light" title="Light"></i>
|
|
<i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark" title="Dark"></i>
|
|
<i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto" title="System Settings"></i>
|
|
</button></div>
|
|
|
|
</div>
|
|
|
|
</div>
|
|
|
|
|
|
<div class="navbar-persistent--mobile">
|
|
|
|
<button class="btn search-button-field search-button__button pst-js-only" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
|
<i class="fa-solid fa-magnifying-glass"></i>
|
|
<span class="search-button__default-text">Search</span>
|
|
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
|
|
</button>
|
|
</div>
|
|
|
|
|
|
|
|
</div>
|
|
|
|
</header>
|
|
|
|
|
|
<div class="bd-container">
|
|
<div class="bd-container__inner bd-page-width">
|
|
|
|
|
|
|
|
<dialog id="pst-primary-sidebar-modal"></dialog>
|
|
<div id="pst-primary-sidebar" class="bd-sidebar-primary bd-sidebar">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<a class="navbar-brand logo" href="../../../index.html">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<img src="../../../_static/nvidia-logo-horiz-rgb-blk-for-screen.svg" class="logo__image only-light" alt="TensorRT-LLM - Home"/>
|
|
<img src="../../../_static/nvidia-logo-horiz-rgb-wht-for-screen.svg" class="logo__image only-dark pst-js-only" alt="TensorRT-LLM - Home"/>
|
|
|
|
|
|
<p class="title logo__title">TensorRT-LLM</p>
|
|
|
|
</a>
|
|
|
|
|
|
|
|
<div class="sidebar-header-items sidebar-primary__section">
|
|
|
|
|
|
<div class="sidebar-header-items__center">
|
|
|
|
|
|
|
|
<div class="navbar-item">
|
|
|
|
|
|
<div class="version-switcher__container dropdown pst-js-only">
|
|
<button id="pst-version-switcher-button-3"
|
|
type="button"
|
|
class="version-switcher__button btn btn-sm dropdown-toggle"
|
|
data-bs-toggle="dropdown"
|
|
aria-haspopup="listbox"
|
|
aria-controls="pst-version-switcher-list-3"
|
|
aria-label="Version switcher list"
|
|
>
|
|
Choose version <!-- this text may get changed later by javascript -->
|
|
<span class="caret"></span>
|
|
</button>
|
|
<div id="pst-version-switcher-list-3"
|
|
class="version-switcher__menu dropdown-menu list-group-flush py-0"
|
|
role="listbox" aria-labelledby="pst-version-switcher-button-3">
|
|
<!-- dropdown will be populated by javascript on page load -->
|
|
</div>
|
|
</div></div>
|
|
|
|
|
|
</div>
|
|
|
|
|
|
|
|
<div class="sidebar-header-items__end">
|
|
|
|
<div class="navbar-item">
|
|
|
|
<button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button pst-js-only" aria-label="Color mode" data-bs-title="Color mode" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
|
<i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light" title="Light"></i>
|
|
<i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark" title="Dark"></i>
|
|
<i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto" title="System Settings"></i>
|
|
</button></div>
|
|
|
|
</div>
|
|
|
|
</div>
|
|
|
|
<div class="sidebar-primary-items__start sidebar-primary__section">
|
|
<div class="sidebar-primary-item">
|
|
|
|
|
|
|
|
<nav class="bd-docs-nav bd-links"
|
|
aria-label="Table of Contents">
|
|
<p class="bd-links__title" role="heading" aria-level="1">Table of Contents</p>
|
|
<div class="bd-toc-item navbar-nav"><p aria-level="2" class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
|
<ul class="nav bd-sidenav">
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../overview.html">Overview</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../quick-start-guide.html">Quick Start Guide</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../key-features.html">Key Features</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../torch.html">PyTorch Backend</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../release-notes.html">Release Notes</a></li>
|
|
</ul>
|
|
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Installation</span></p>
|
|
<ul class="nav bd-sidenav">
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../installation/linux.html">Installing on Linux</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../installation/build-from-source-linux.html">Building from Source Code on Linux</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../installation/grace-hopper.html">Installing on Grace Hopper</a></li>
|
|
</ul>
|
|
<p aria-level="2" class="caption" role="heading"><span class="caption-text">LLM API</span></p>
|
|
<ul class="nav bd-sidenav">
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../llm-api/index.html">API Introduction</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../llm-api/reference.html">API Reference</a></li>
|
|
</ul>
|
|
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
|
|
<ul class="nav bd-sidenav">
|
|
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../examples/index.html">LLM Examples Introduction</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_medusa_decoding.html">Generate Text Using Medusa Decoding</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_multilora.html">Generate text with multiple LoRA adapters</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_eagle_decoding.html">Generate Text Using Eagle Decoding</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_inference_async.html">Generate Text Asynchronously</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_inference_distributed.html">Distributed LLM Generation</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_logits_processor.html">Control generated text using logits processor</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_eagle2_decoding.html">Generate Text Using Eagle2 Decoding</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_inference_kv_events.html">Get KV Cache Events</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_lookahead_decoding.html">Generate Text Using Lookahead Decoding</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_quantization.html">Generation with Quantization</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_inference_async_streaming.html">Generate Text in Streaming</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_guided_decoding.html">Generate text with guided decoding</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_inference.html">Generate text</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_inference_customize.html">Generate text with customization</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_auto_parallel.html">Automatic Parallelism with LLM</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_mgmn_llm_distributed.html">Llm Mgmn Llm Distributed</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_mgmn_trtllm_bench.html">Llm Mgmn Trtllm Bench</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_mgmn_trtllm_serve.html">Llm Mgmn Trtllm Serve</a></li>
|
|
</ul>
|
|
</details></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../examples/customization.html">LLM Common Customizations</a></li>
|
|
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../examples/llm_api_examples.html">LLM Examples</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_medusa_decoding.html">Generate Text Using Medusa Decoding</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_multilora.html">Generate text with multiple LoRA adapters</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_eagle_decoding.html">Generate Text Using Eagle Decoding</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_inference_async.html">Generate Text Asynchronously</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_inference_distributed.html">Distributed LLM Generation</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_logits_processor.html">Control generated text using logits processor</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_eagle2_decoding.html">Generate Text Using Eagle2 Decoding</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_inference_kv_events.html">Get KV Cache Events</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_lookahead_decoding.html">Generate Text Using Lookahead Decoding</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_quantization.html">Generation with Quantization</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_inference_async_streaming.html">Generate Text in Streaming</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_guided_decoding.html">Generate text with guided decoding</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_inference.html">Generate text</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_inference_customize.html">Generate text with customization</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_auto_parallel.html">Automatic Parallelism with LLM</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_mgmn_llm_distributed.html">Llm Mgmn Llm Distributed</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_mgmn_trtllm_bench.html">Llm Mgmn Trtllm Bench</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_mgmn_trtllm_serve.html">Llm Mgmn Trtllm Serve</a></li>
|
|
</ul>
|
|
</details></li>
|
|
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../examples/trtllm_serve_examples.html">Online Serving Examples</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/curl_chat_client.html">Curl Chat Client</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/curl_chat_client_for_multimodal.html">Curl Chat Client For Multimodal</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/curl_completion_client.html">Curl Completion Client</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/deepseek_r1_reasoning_parser.html">Deepseek R1 Reasoning Parser</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/genai_perf_client.html">Genai Perf Client</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/genai_perf_client_for_multimodal.html">Genai Perf Client For Multimodal</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/openai_chat_client.html">OpenAI Chat Client</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/openai_chat_client_for_multimodal.html">OpenAI Chat Client</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../examples/openai_completion_client.html">OpenAI Completion Client</a></li>
|
|
</ul>
|
|
</details></li>
|
|
</ul>
|
|
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Model Definition API</span></p>
|
|
<ul class="nav bd-sidenav">
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.layers.html">Layers</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.functional.html">Functionals</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.models.html">Models</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.plugin.html">Plugin</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.quantization.html">Quantization</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../python-api/tensorrt_llm.runtime.html">Runtime</a></li>
|
|
</ul>
|
|
<p aria-level="2" class="caption" role="heading"><span class="caption-text">C++ API</span></p>
|
|
<ul class="nav bd-sidenav">
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../_cpp_gen/executor.html">Executor</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../_cpp_gen/runtime.html">Runtime</a></li>
|
|
</ul>
|
|
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Command-Line Reference</span></p>
|
|
<ul class="nav bd-sidenav">
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../commands/trtllm-build.html">trtllm-build</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../commands/trtllm-serve.html">trtllm-serve</a></li>
|
|
</ul>
|
|
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Architecture</span></p>
|
|
<ul class="nav bd-sidenav">
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/overview.html">TensorRT-LLM Architecture</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/core-concepts.html">Model Definition</a></li>
|
|
|
|
|
|
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/checkpoint.html">TensorRT-LLM Checkpoint</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/workflow.html">TensorRT-LLM Build Workflow</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../architecture/add-model.html">Adding a Model</a></li>
|
|
</ul>
|
|
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Advanced</span></p>
|
|
<ul class="nav bd-sidenav">
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/gpt-attention.html">Multi-Head, Multi-Query, and Group-Query Attention</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/gpt-runtime.html">C++ GPT Runtime</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/executor.html">Executor API</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/graph-rewriting.html">Graph Rewriting Module</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/lora.html">Run gpt-2b + LoRA using Executor / cpp runtime</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/kv-cache-management.html">KV Cache Management: Pools, Blocks, and Events</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../advanced/disaggregated-service.html">Disaggregated-Service (experimental)</a></li>
|
|
</ul>
|
|
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Performance</span></p>
|
|
<ul class="nav bd-sidenav">
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../performance/perf-overview.html">Overview</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../performance/perf-benchmarking.html">Benchmarking</a></li>
|
|
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../performance/performance-tuning-guide/index.html">Performance Tuning Guide</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../performance/performance-tuning-guide/benchmarking-default-performance.html">Benchmarking Default Performance</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../performance/performance-tuning-guide/useful-build-time-flags.html">Useful Build-Time Flags</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../performance/performance-tuning-guide/tuning-max-batch-size-and-max-num-tokens.html">Tuning Max Batch Size and Max Num Tokens</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../performance/performance-tuning-guide/deciding-model-sharding-strategy.html">Deciding Model Sharding Strategy</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../performance/performance-tuning-guide/fp8-quantization.html">FP8 Quantization</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="../../../performance/performance-tuning-guide/useful-runtime-flags.html">Useful Runtime Options</a></li>
|
|
</ul>
|
|
</details></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../performance/perf-analysis.html">Performance Analysis</a></li>
|
|
</ul>
|
|
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Reference</span></p>
|
|
<ul class="nav bd-sidenav">
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../reference/troubleshooting.html">Troubleshooting</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../reference/support-matrix.html">Support Matrix</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../reference/precision.html">Numerical Precision</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../reference/memory.html">Memory Usage of TensorRT-LLM</a></li>
|
|
</ul>
|
|
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Blogs</span></p>
|
|
<ul class="nav bd-sidenav">
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT-LLM</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/Falcon180B-H200.html">Falcon-180B on a single H200 GPU with INT4 AWQ, and 6.7x faster Llama-70B over A100</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/quantization-in-TRT-LLM.html">Speed up inference with SOTA quantization techniques in TRT-LLM</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/XQA-kernel.html">New XQA-kernel provides 2.4x more Llama-70B throughput within the same latency budget</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/tech_blog/blog1_Pushing_Latency_Boundaries_Optimizing_DeepSeek-R1_Performance_on_NVIDIA_B200_GPUs.html">Pushing Latency Boundaries: Optimizing DeepSeek-R1 Performance on NVIDIA B200 GPUs</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/tech_blog/blog2_DeepSeek_R1_MTP_Implementation_and_Optimization.html">DeepSeek R1 MTP Implementation and Optimization</a></li>
|
|
</ul>
|
|
</div>
|
|
</nav></div>
|
|
</div>
|
|
|
|
|
|
<div class="sidebar-primary-items__end sidebar-primary__section">
|
|
</div>
|
|
|
|
|
|
|
|
</div>
|
|
|
|
<main id="main-content" class="bd-main" role="main">
|
|
|
|
|
|
<div class="bd-content">
|
|
<div class="bd-article-container">
|
|
|
|
<div class="bd-header-article d-print-none">
|
|
<div class="header-article-items header-article__inner">
|
|
|
|
<div class="header-article-items__start">
|
|
|
|
<div class="header-article-item">
|
|
|
|
<nav aria-label="Breadcrumb" class="d-print-none">
|
|
<ul class="bd-breadcrumbs">
|
|
|
|
<li class="breadcrumb-item breadcrumb-home">
|
|
<a href="../../../index.html" class="nav-link" aria-label="Home">
|
|
<i class="fa-solid fa-home"></i>
|
|
</a>
|
|
</li>
|
|
|
|
<li class="breadcrumb-item"><a href="../../index.html" class="nav-link">Module code</a></li>
|
|
|
|
<li class="breadcrumb-item active" aria-current="page"><span class="ellipsis">tensorrt_llm.runtime.generation</span></li>
|
|
</ul>
|
|
</nav>
|
|
</div>
|
|
|
|
</div>
|
|
|
|
|
|
</div>
|
|
</div>
|
|
|
|
|
|
|
|
|
|
<div id="searchbox"></div>
|
|
<article class="bd-article">
|
|
|
|
<h1>Source code for tensorrt_llm.runtime.generation</h1><div class="highlight"><pre>
|
|
<span></span><span class="c1"># SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.</span>
|
|
<span class="c1"># SPDX-License-Identifier: Apache-2.0</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># Licensed under the Apache License, Version 2.0 (the "License");</span>
|
|
<span class="c1"># you may not use this file except in compliance with the License.</span>
|
|
<span class="c1"># You may obtain a copy of the License at</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
|
|
<span class="c1">#</span>
|
|
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
|
|
<span class="c1"># distributed under the License is distributed on an "AS IS" BASIS,</span>
|
|
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
|
|
<span class="c1"># See the License for the specific language governing permissions and</span>
|
|
<span class="c1"># limitations under the License.</span>
|
|
|
|
<span class="kn">import</span><span class="w"> </span><span class="nn">copy</span>
|
|
<span class="kn">import</span><span class="w"> </span><span class="nn">math</span>
|
|
<span class="kn">import</span><span class="w"> </span><span class="nn">os</span>
|
|
<span class="kn">import</span><span class="w"> </span><span class="nn">platform</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">collections</span><span class="w"> </span><span class="kn">import</span> <span class="n">Counter</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">dataclasses</span><span class="w"> </span><span class="kn">import</span> <span class="n">dataclass</span><span class="p">,</span> <span class="n">field</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">functools</span><span class="w"> </span><span class="kn">import</span> <span class="n">reduce</span><span class="p">,</span> <span class="n">wraps</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">pathlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">Path</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">typing</span><span class="w"> </span><span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Iterable</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">,</span> <span class="n">Set</span><span class="p">,</span> <span class="n">Union</span>
|
|
|
|
<span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
|
|
|
|
<span class="c1"># isort: off</span>
|
|
<span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
|
|
<span class="kn">import</span><span class="w"> </span><span class="nn">tensorrt</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">trt</span>
|
|
<span class="c1"># isort: on</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">cuda</span><span class="w"> </span><span class="kn">import</span> <span class="n">cudart</span>
|
|
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">tensorrt_llm.runtime.memory_pools.memory_pools_allocator</span><span class="w"> </span><span class="kn">import</span> \
|
|
<span class="n">MemoryPoolsAllocator</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">tensorrt_llm.runtime.memory_pools.pools_kv_cache_manager</span><span class="w"> </span><span class="kn">import</span> \
|
|
<span class="n">PoolsKVCacheManager</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">tensorrt_llm.runtime.redrafter_utils</span><span class="w"> </span><span class="kn">import</span> <span class="o">*</span>
|
|
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">.._utils</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span><span class="n">pad_vocab_size</span><span class="p">,</span> <span class="n">str_dtype_to_torch</span><span class="p">,</span> <span class="n">torch_to_numpy</span><span class="p">,</span>
|
|
<span class="n">trt_dtype_to_torch</span><span class="p">)</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">..bindings</span><span class="w"> </span><span class="kn">import</span> <span class="n">KVCacheType</span><span class="p">,</span> <span class="n">ipc_nvls_allocate</span><span class="p">,</span> <span class="n">ipc_nvls_free</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">..layers</span><span class="w"> </span><span class="kn">import</span> <span class="n">LanguageAdapterConfig</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">..logger</span><span class="w"> </span><span class="kn">import</span> <span class="n">logger</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">..lora_manager</span><span class="w"> </span><span class="kn">import</span> <span class="n">LoraManager</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">..mapping</span><span class="w"> </span><span class="kn">import</span> <span class="n">Mapping</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">..plugin.plugin</span><span class="w"> </span><span class="kn">import</span> <span class="n">CustomAllReduceHelper</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">..quantization</span><span class="w"> </span><span class="kn">import</span> <span class="n">QuantMode</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">.kv_cache_manager</span><span class="w"> </span><span class="kn">import</span> <span class="n">GenerationSequence</span><span class="p">,</span> <span class="n">KVCacheUpdater</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">.session</span><span class="w"> </span><span class="kn">import</span> <span class="n">_scoped_stream</span>
|
|
|
|
<span class="c1"># When variable is set, this will disable torch.cuda.set_device(...) calls</span>
|
|
<span class="c1"># Useful in situations where device is already assigned by another library, i.e., megatron.</span>
|
|
<span class="n">DISABLE_TORCH_DEVICE_SET</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">"DISABLE_TORCH_DEVICE_SET"</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
|
|
|
|
|
|
<div class="viewcode-block" id="decode_words_list">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.decode_words_list">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">decode_words_list</span><span class="p">(</span><span class="n">word_dict</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]],</span>
|
|
<span class="n">tokenizer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">add_special_tokens</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">'''</span>
|
|
<span class="sd"> format of word_dict</span>
|
|
<span class="sd"> len(word_dict) should be same to batch_size</span>
|
|
<span class="sd"> word_dict[i] means the words for batch i</span>
|
|
<span class="sd"> len(word_dict[i]) >= 1, which means it must contain at least 1 string</span>
|
|
<span class="sd"> For example, word_dict[2] = [" I am happy", " I am sad"].</span>
|
|
<span class="sd"> '''</span>
|
|
<span class="k">assert</span> <span class="n">tokenizer</span> <span class="o">!=</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"need to set tokenizer"</span>
|
|
|
|
<span class="n">decoded_words_batch</span> <span class="o">=</span> <span class="p">[]</span>
|
|
<span class="k">for</span> <span class="n">word_dict_item</span> <span class="ow">in</span> <span class="n">word_dict</span><span class="p">:</span>
|
|
<span class="n">decoded_words_request</span> <span class="o">=</span> <span class="p">[]</span>
|
|
|
|
<span class="k">for</span> <span class="n">item</span> <span class="ow">in</span> <span class="n">word_dict_item</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">item</span><span class="p">,</span> <span class="nb">bytes</span><span class="p">):</span>
|
|
<span class="n">item</span> <span class="o">=</span> <span class="p">[</span><span class="n">item</span><span class="o">.</span><span class="n">decode</span><span class="p">()]</span>
|
|
|
|
<span class="n">ids</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">item</span><span class="p">,</span> <span class="n">add_special_tokens</span><span class="o">=</span><span class="n">add_special_tokens</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="k">continue</span>
|
|
|
|
<span class="n">decoded_words_request</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">ids</span><span class="p">)</span>
|
|
<span class="n">decoded_words_batch</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">decoded_words_request</span><span class="p">)</span>
|
|
|
|
<span class="k">return</span> <span class="n">decoded_words_batch</span></div>
|
|
|
|
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">to_word_list_format</span><span class="p">(</span><span class="n">word_dict</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]]):</span>
|
|
<span class="w"> </span><span class="sd">'''</span>
|
|
<span class="sd"> format of word_dict</span>
|
|
<span class="sd"> len(word_dict) should be same to batch_size</span>
|
|
<span class="sd"> word_dict[i] means the words for batch i</span>
|
|
<span class="sd"> len(word_dict[i]) >= 1, which means it must contain at least 1 word</span>
|
|
<span class="sd"> For example, word_dict[2] = [[1, 267], [534]] has two words.</span>
|
|
<span class="sd"> '''</span>
|
|
|
|
<span class="n">flat_ids</span> <span class="o">=</span> <span class="p">[]</span>
|
|
<span class="n">offsets</span> <span class="o">=</span> <span class="p">[]</span>
|
|
<span class="k">for</span> <span class="n">word_dict_item</span> <span class="ow">in</span> <span class="n">word_dict</span><span class="p">:</span>
|
|
<span class="n">items_flat_ids</span> <span class="o">=</span> <span class="p">[]</span>
|
|
<span class="n">items_offsets</span> <span class="o">=</span> <span class="p">[]</span>
|
|
|
|
<span class="k">for</span> <span class="n">ids</span> <span class="ow">in</span> <span class="n">word_dict_item</span><span class="p">:</span>
|
|
<span class="n">items_flat_ids</span> <span class="o">+=</span> <span class="n">ids</span>
|
|
<span class="n">items_offsets</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="p">))</span>
|
|
|
|
<span class="n">flat_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">items_flat_ids</span><span class="p">))</span>
|
|
<span class="n">offsets</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">items_offsets</span><span class="p">)))</span>
|
|
|
|
<span class="n">pad_to</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">max</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="p">)</span> <span class="k">for</span> <span class="n">ids</span> <span class="ow">in</span> <span class="n">flat_ids</span><span class="p">))</span>
|
|
|
|
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="n">offs</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">flat_ids</span><span class="p">,</span> <span class="n">offsets</span><span class="p">)):</span>
|
|
<span class="n">flat_ids</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">pad_to</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="p">)),</span> <span class="n">constant_values</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
|
<span class="n">offsets</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="n">offs</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">pad_to</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">offs</span><span class="p">)),</span> <span class="n">constant_values</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
|
|
|
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">flat_ids</span><span class="p">,</span> <span class="n">offsets</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int32"</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>
|
|
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_prepare_input_ids</span><span class="p">(</span><span class="n">tensors</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span>
|
|
<span class="n">tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">t</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">]</span>
|
|
<span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span><span class="n">tensors</span><span class="p">)</span>
|
|
<span class="n">row_lengths</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">]</span>
|
|
<span class="n">row_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">row_lengths</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="n">data</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">row_lengths</span><span class="p">)</span>
|
|
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">CUASSERT</span><span class="p">(</span><span class="n">cuda_ret</span><span class="p">):</span>
|
|
<span class="n">err</span> <span class="o">=</span> <span class="n">cuda_ret</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="k">if</span> <span class="n">err</span> <span class="o">!=</span> <span class="n">cudart</span><span class="o">.</span><span class="n">cudaError_t</span><span class="o">.</span><span class="n">cudaSuccess</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s2">"CUDA ERROR: </span><span class="si">{</span><span class="n">err</span><span class="si">}</span><span class="s2">, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t"</span>
|
|
<span class="p">)</span>
|
|
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">cuda_ret</span><span class="p">)</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="n">cuda_ret</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span>
|
|
<span class="k">return</span> <span class="kc">None</span>
|
|
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_update_cuda_graph_instance</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="n">graph</span><span class="p">):</span>
|
|
<span class="n">err</span> <span class="o">=</span> <span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphExecUpdate</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="n">graph</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">err</span> <span class="o">!=</span> <span class="n">cudart</span><span class="o">.</span><span class="n">cudaError_t</span><span class="o">.</span><span class="n">cudaSuccess</span><span class="p">:</span>
|
|
<span class="c1"># When updating cuda graph failed, destroy and instantiate one.</span>
|
|
<span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphExecDestroy</span><span class="p">(</span><span class="n">instance</span><span class="p">))</span>
|
|
<span class="n">instance</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphInstantiate</span><span class="p">(</span><span class="n">graph</span><span class="p">,</span> <span class="mi">0</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="k">return</span> <span class="n">instance</span>
|
|
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_prepare_attention_mask</span><span class="p">(</span><span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">pad_id</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
|
|
<span class="n">is_pad_id_in_inputs</span> <span class="o">=</span> <span class="p">(</span><span class="n">pad_id</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="ow">and</span> <span class="p">(</span><span class="n">pad_id</span> <span class="ow">in</span> <span class="n">input_ids</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">input_ids</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">is_pad_id_in_inputs</span><span class="p">:</span>
|
|
<span class="n">mask</span> <span class="o">=</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">ne</span><span class="p">(</span><span class="n">pad_id</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
|
|
<span class="c1"># for enc-dec models, pad_id could be the start token and should be always counted</span>
|
|
<span class="c1"># as valid token rather than padded token, so we force its mask to be 1.</span>
|
|
<span class="c1"># This doesn't impact the existing behavior</span>
|
|
<span class="n">mask</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
|
|
<span class="k">return</span> <span class="n">mask</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="n">input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_tile_beam_width</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
|
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
|
<span class="n">new_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">num_beams</span>
|
|
|
|
<span class="n">tile_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">new_shape</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
<span class="n">tile_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="n">tile_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>
|
|
|
|
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">tile_size</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
|
|
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">new_shape</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
|
|
<span class="k">return</span> <span class="n">new_tensor</span>
|
|
|
|
|
|
<span class="k">class</span><span class="w"> </span><span class="nc">_Profiler</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">IProfiler</span><span class="p">):</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">results</span> <span class="o">=</span> <span class="p">[]</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">report_layer_time</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer_name</span><span class="p">,</span> <span class="n">ms</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">results</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">layer_name</span><span class="p">,</span> <span class="n">ms</span><span class="p">))</span>
|
|
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_contiguous_tile_beam_width</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">num_beams</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
|
<span class="n">new_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
|
<span class="n">new_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*=</span> <span class="n">num_beams</span>
|
|
|
|
<span class="n">numel</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
|
|
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">num_beams</span> <span class="o">*</span> <span class="n">numel</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="n">tensor</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
|
|
|
<span class="c1"># Take the first 'size' values to tile and skip the others.</span>
|
|
<span class="n">vals</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)[:</span><span class="n">size</span><span class="p">]</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_beams</span><span class="p">):</span>
|
|
<span class="n">new_tensor</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">size</span><span class="p">:(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">size</span><span class="p">]</span> <span class="o">=</span> <span class="n">vals</span>
|
|
|
|
<span class="k">return</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
|
|
|
|
|
|
<span class="k">class</span><span class="w"> </span><span class="nc">_Runtime</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
|
|
<span class="n">runtime_rank</span><span class="p">:</span> <span class="nb">int</span>
|
|
<span class="n">runtime</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">Runtime</span>
|
|
<span class="n">engine</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ICudaEngine</span>
|
|
<span class="n">ctx_context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span>
|
|
<span class="n">context_0</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span>
|
|
<span class="n">context_1</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span>
|
|
<span class="n">profiler</span><span class="p">:</span> <span class="n">_Profiler</span>
|
|
<span class="n">engine_inspector</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">EngineInspector</span>
|
|
<span class="n">cuda_graph_instances</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphExec_t</span><span class="p">]</span>
|
|
<span class="n">input_tensor_names</span><span class="p">:</span> <span class="n">Set</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span>
|
|
<span class="n">output_tensor_names</span><span class="p">:</span> <span class="n">Set</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">,</span> <span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">address</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">__prepare</span><span class="p">(</span><span class="n">mapping</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">)</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_serialize_engine</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">trt</span><span class="o">.</span><span class="n">IHostMemory</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">serialize</span><span class="p">()</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">__create_and_setup_context</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">address</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">profile_idx</span><span class="p">,</span>
|
|
<span class="n">stream</span><span class="p">)</span> <span class="o">-></span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">:</span>
|
|
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">create_execution_context_without_device_memory</span><span class="p">()</span>
|
|
<span class="k">assert</span> <span class="n">context</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"Failed to create an execution context with the provided device memory!"</span>
|
|
<span class="n">context</span><span class="o">.</span><span class="n">set_device_memory</span><span class="p">(</span><span class="n">address</span><span class="p">,</span> <span class="n">size</span><span class="p">)</span>
|
|
<span class="n">context</span><span class="o">.</span><span class="n">set_optimization_profile_async</span><span class="p">(</span><span class="n">profile_idx</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
|
|
<span class="c1"># If nvtx verbosity is DETAILED, change it to LAYER_NAMES_ONLY for inference performance</span>
|
|
<span class="k">if</span> <span class="n">context</span><span class="o">.</span><span class="n">nvtx_verbosity</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">ProfilingVerbosity</span><span class="o">.</span><span class="n">DETAILED</span><span class="p">:</span>
|
|
<span class="n">context</span><span class="o">.</span><span class="n">nvtx_verbosity</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ProfilingVerbosity</span><span class="o">.</span><span class="n">LAYER_NAMES_ONLY</span>
|
|
<span class="k">return</span> <span class="n">context</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_set_profiler</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">profiler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="k">return</span>
|
|
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">context_0</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
|
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">context_1</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">profiler</span> <span class="o">=</span> <span class="n">_Profiler</span><span class="p">()</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">context_0</span><span class="o">.</span><span class="n">profiler</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">profiler</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">context_0</span><span class="o">.</span><span class="n">enqueue_emits_profile</span> <span class="o">=</span> <span class="kc">False</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">context_1</span><span class="o">.</span><span class="n">profiler</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">profiler</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">context_1</span><span class="o">.</span><span class="n">enqueue_emits_profile</span> <span class="o">=</span> <span class="kc">False</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_optimization_profiles</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
|
|
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span><span class="o">.</span><span class="n">profiler</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">profiler</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span><span class="o">.</span><span class="n">enqueue_emits_profile</span> <span class="o">=</span> <span class="kc">False</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">__prepare</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime_rank</span> <span class="o">=</span> <span class="n">mapping</span><span class="o">.</span><span class="n">rank</span>
|
|
<span class="n">local_rank</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime_rank</span> <span class="o">%</span> <span class="n">mapping</span><span class="o">.</span><span class="n">gpus_per_node</span>
|
|
<span class="k">if</span> <span class="n">DISABLE_TORCH_DEVICE_SET</span><span class="p">:</span>
|
|
<span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaSetDevice</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_device</span><span class="p">()))</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="n">local_rank</span><span class="p">)</span>
|
|
<span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaSetDevice</span><span class="p">(</span><span class="n">local_rank</span><span class="p">))</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Runtime</span><span class="p">(</span><span class="n">logger</span><span class="o">.</span><span class="n">trt_logger</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">engine</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">deserialize_cuda_engine</span><span class="p">(</span><span class="n">engine_buffer</span><span class="p">)</span>
|
|
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">input_tensor_names</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">output_tensor_names</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">):</span>
|
|
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_mode</span><span class="p">(</span><span class="n">name</span><span class="p">)</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">TensorIOMode</span><span class="o">.</span><span class="n">OUTPUT</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">output_tensor_names</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">input_tensor_names</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">profiler</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">engine_inspector</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">create_engine_inspector</span><span class="p">()</span>
|
|
<span class="c1"># cuda graph ping-pong instances</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_instances</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">)]</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">streamable_weights_size</span><span class="p">:</span>
|
|
<span class="c1"># engine does not have weight streaming enabled</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">__prepare_execution_contexts</span><span class="p">()</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">weight_streaming_budget_v2</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># avoid OOM when print engine info</span>
|
|
|
|
<span class="k">if</span> <span class="n">logger</span><span class="o">.</span><span class="n">level</span> <span class="o">==</span> <span class="s2">"verbose"</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">__print_engine_info</span><span class="p">()</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">__prepare_execution_contexts</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">context_0</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">context_1</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span> <span class="o">=</span> <span class="kc">None</span>
|
|
|
|
<span class="c1"># The device_memory_size_v2 stores the memory required by the largest profile.</span>
|
|
<span class="c1"># When weight streaming is enable, it must be queried after the weight streaming budget set.</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">device_memory_size_v2</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">device_memory_size_v2</span>
|
|
<span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaFree</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">))</span>
|
|
<span class="n">address</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaMalloc</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">address</span> <span class="o">=</span> <span class="n">address</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">device_memory_size_v2</span>
|
|
<span class="n">address</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaMalloc</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">address</span> <span class="o">=</span> <span class="n">address</span>
|
|
|
|
<span class="k">with</span> <span class="n">_scoped_stream</span><span class="p">()</span> <span class="k">as</span> <span class="n">stream</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_optimization_profiles</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
|
<span class="c1"># At step = 0, context_1 is active</span>
|
|
<span class="c1"># At step = 1, context_0 is active</span>
|
|
<span class="c1"># At step = 2, context_1 is active</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">context_0</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">context_1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">context_1</span>
|
|
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_optimization_profiles</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
|
|
<span class="c1"># At step = 0, ctx_context is active</span>
|
|
<span class="c1"># At step = 1, context_0 is active</span>
|
|
<span class="c1"># At step = 2, context_1 is active</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">context_0</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">context_1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__create_and_setup_context</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">device_memory_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s2">"Number of optimization profiles: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_optimization_profiles</span><span class="si">}</span><span class="s2">"</span>
|
|
<span class="p">)</span>
|
|
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span>
|
|
<span class="s2">"Python runtime only support 1 or 2 optimization profiles, "</span>
|
|
<span class="s2">"set --multiple_profiles=disable when calling trtllm-build "</span>
|
|
<span class="s2">"to disable the feature."</span><span class="p">)</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">__print_engine_info</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">engine</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span>
|
|
<span class="n">context</span> <span class="o">=</span> <span class="n">engine</span><span class="o">.</span><span class="n">create_execution_context</span><span class="p">(</span>
|
|
<span class="n">trt</span><span class="o">.</span><span class="n">ExecutionContextAllocationStrategy</span><span class="o">.</span><span class="n">USER_MANAGED</span><span class="p">)</span>
|
|
<span class="n">n_op</span> <span class="o">=</span> <span class="n">engine</span><span class="o">.</span><span class="n">num_optimization_profiles</span>
|
|
<span class="n">max_name_width</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># Maximum Width of tensor Name</span>
|
|
<span class="n">max_shape_width</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># Maximum Width of tensor Shape</span>
|
|
<span class="n">tensor_name_list</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">)</span>
|
|
<span class="p">]</span>
|
|
|
|
<span class="c1"># Get information of engine input / output</span>
|
|
<span class="n">tid</span> <span class="o">=</span> <span class="p">{}</span> <span class="c1"># Tensor Information Dictionary</span>
|
|
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">tensor_name_list</span><span class="p">:</span>
|
|
<span class="n">item</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
|
|
<span class="n">max_name_width</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">max_name_width</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
|
|
<span class="n">item</span><span class="p">[</span><span class="s2">"mode"</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'I'</span> <span class="k">if</span> <span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_mode</span><span class="p">(</span>
|
|
<span class="n">name</span><span class="p">)</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">TensorIOMode</span><span class="o">.</span><span class="n">INPUT</span> <span class="k">else</span> <span class="s1">'O'</span>
|
|
<span class="n">item</span><span class="p">[</span><span class="s2">"location"</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'GPU'</span> <span class="k">if</span> <span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_location</span><span class="p">(</span>
|
|
<span class="n">name</span><span class="p">)</span> <span class="k">else</span> <span class="s1">'CPU'</span>
|
|
<span class="n">item</span><span class="p">[</span><span class="s2">"data_type"</span><span class="p">]</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">))[</span><span class="mi">9</span><span class="p">:]</span>
|
|
<span class="n">item</span><span class="p">[</span><span class="s2">"build_shape"</span><span class="p">]</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_shape</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
|
|
<span class="n">item</span><span class="p">[</span><span class="s2">"profile_list"</span><span class="p">]</span> <span class="o">=</span> <span class="p">[[]</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_op</span><span class="p">)]</span>
|
|
<span class="k">if</span> <span class="n">item</span><span class="p">[</span><span class="s2">"mode"</span><span class="p">]</span> <span class="o">==</span> <span class="s2">"I"</span><span class="p">:</span>
|
|
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_op</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="n">item</span><span class="p">[</span><span class="s2">"location"</span><span class="p">]</span> <span class="o">==</span> <span class="s2">"GPU"</span><span class="p">:</span>
|
|
<span class="n">shape</span> <span class="o">=</span> <span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_profile_shape</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">shape</span> <span class="o">=</span> <span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_profile_value</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">name</span><span class="p">)</span>
|
|
<span class="n">item</span><span class="p">[</span><span class="s2">"profile_list"</span><span class="p">][</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
|
|
<span class="n">max_shape_width</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">max_shape_width</span><span class="p">,</span>
|
|
<span class="o">*</span><span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">s</span><span class="p">))</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">shape</span><span class="p">])</span>
|
|
<span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">item</span>
|
|
<span class="c1"># Set input shape to get output shape</span>
|
|
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_op</span><span class="p">):</span>
|
|
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">3</span><span class="p">):</span> <span class="c1"># Min, Opt, Max</span>
|
|
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">tid</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
|
|
<span class="k">if</span> <span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="s2">"mode"</span><span class="p">]</span> <span class="o">==</span> <span class="s2">"I"</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="s2">"location"</span><span class="p">]</span> <span class="o">==</span> <span class="s2">"GPU"</span><span class="p">:</span>
|
|
<span class="n">context</span><span class="o">.</span><span class="n">set_input_shape</span><span class="p">(</span>
|
|
<span class="n">name</span><span class="p">,</span> <span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="s2">"profile_list"</span><span class="p">][</span><span class="n">k</span><span class="p">][</span><span class="n">j</span><span class="p">])</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">context</span><span class="o">.</span><span class="n">set_tensor_address</span><span class="p">(</span>
|
|
<span class="n">name</span><span class="p">,</span>
|
|
<span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="s2">"profile_list"</span><span class="p">][</span><span class="n">k</span><span class="p">][</span><span class="n">j</span><span class="p">]</span><span class="o">.</span><span class="n">ctypes</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
|
|
<span class="k">elif</span> <span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="s2">"mode"</span><span class="p">]</span> <span class="o">==</span> <span class="s2">"O"</span><span class="p">:</span>
|
|
<span class="k">assert</span> <span class="n">context</span><span class="o">.</span><span class="n">all_binding_shapes_specified</span> <span class="ow">and</span> <span class="n">context</span><span class="o">.</span><span class="n">all_shape_inputs_specified</span>
|
|
<span class="n">shape</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">get_tensor_shape</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
|
<span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">][</span><span class="s2">"profile_list"</span><span class="p">][</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
|
|
|
|
<span class="c1"># Print information of engine input / output</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="s2">"Information of engine input / output."</span><span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="s1">'='</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">24</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="s1">'Name'</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_name_width</span><span class="si">}}</span><span class="s2">|I/O|Location|DataType|</span><span class="si">{</span><span class="s1">'Shape'</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|"</span>
|
|
<span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="s1">'-'</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">24</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">tensor_name_list</span><span class="p">:</span>
|
|
<span class="n">item</span> <span class="o">=</span> <span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
|
|
<span class="n">info</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">name</span><span class="si">:</span><span class="s2"><</span><span class="si">{</span><span class="n">max_name_width</span><span class="si">}}</span><span class="s2">|</span><span class="si">{</span><span class="n">item</span><span class="p">[</span><span class="s1">'mode'</span><span class="p">]</span><span class="si">:</span><span class="s2">^3s</span><span class="si">}</span><span class="s2">|</span><span class="si">{</span><span class="n">item</span><span class="p">[</span><span class="s1">'location'</span><span class="p">]</span><span class="si">:</span><span class="s2">^8s</span><span class="si">}</span><span class="s2">|</span><span class="si">{</span><span class="n">item</span><span class="p">[</span><span class="s1">'data_type'</span><span class="p">]</span><span class="si">:</span><span class="s2">^8s</span><span class="si">}</span><span class="s2">|"</span>
|
|
<span class="n">info</span> <span class="o">+=</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">item</span><span class="p">[</span><span class="s1">'build_shape'</span><span class="p">]</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|"</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="n">info</span><span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="s1">'='</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">24</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
<span class="c1"># Print information of optimization profile</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="s2">"Information of optimization profile."</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_op</span><span class="p">):</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Optimization Profile </span><span class="si">{</span><span class="n">k</span><span class="si">}</span><span class="s2">:"</span><span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="s1">'='</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">3</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">4</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="s1">'Name'</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_name_width</span><span class="si">}}</span><span class="s2">|</span><span class="si">{</span><span class="s1">'Min'</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|</span><span class="si">{</span><span class="s1">'Opt'</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|</span><span class="si">{</span><span class="s1">'Max'</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|"</span>
|
|
<span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="s1">'-'</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">3</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">4</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">tensor_name_list</span><span class="p">:</span>
|
|
<span class="n">item</span> <span class="o">=</span> <span class="n">tid</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
|
|
<span class="n">info</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">name</span><span class="si">:</span><span class="s2"><</span><span class="si">{</span><span class="n">max_name_width</span><span class="si">}}</span><span class="s2">|"</span>
|
|
<span class="n">info</span> <span class="o">+=</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="nb">str</span><span class="p">(</span><span class="n">item</span><span class="p">[</span><span class="s1">'profile_list'</span><span class="p">][</span><span class="n">k</span><span class="p">][</span><span class="mi">0</span><span class="p">])</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|"</span>
|
|
<span class="n">info</span> <span class="o">+=</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="nb">str</span><span class="p">(</span><span class="n">item</span><span class="p">[</span><span class="s1">'profile_list'</span><span class="p">][</span><span class="n">k</span><span class="p">][</span><span class="mi">1</span><span class="p">])</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|"</span>
|
|
<span class="n">info</span> <span class="o">+=</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="nb">str</span><span class="p">(</span><span class="n">item</span><span class="p">[</span><span class="s1">'profile_list'</span><span class="p">][</span><span class="n">k</span><span class="p">][</span><span class="mi">2</span><span class="p">])</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|"</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="n">info</span><span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="s1">'='</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">3</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">4</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">print_context_info</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">,</span> <span class="n">context_index</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">n_io</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span>
|
|
<span class="n">max_name_width</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># Maximum Width of tensor Name</span>
|
|
<span class="n">max_shape_width</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># Maximum Width of tensor Shape</span>
|
|
<span class="n">tensorInfo</span> <span class="o">=</span> <span class="p">{}</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_io</span><span class="p">):</span>
|
|
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
|
|
<span class="n">b_input</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_mode</span><span class="p">(</span>
|
|
<span class="n">name</span><span class="p">)</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">TensorIOMode</span><span class="o">.</span><span class="n">INPUT</span>
|
|
<span class="n">shape</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_shape</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
|
|
<span class="n">tensorInfo</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="n">name</span><span class="p">,</span> <span class="n">b_input</span><span class="p">,</span> <span class="n">shape</span><span class="p">]</span>
|
|
<span class="n">max_name_width</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">max_name_width</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
|
|
<span class="n">max_shape_width</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">max_shape_width</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">shape</span><span class="p">))</span>
|
|
<span class="c1"># Shape input tensor is not used in TRT-LLM yet</span>
|
|
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Information of context input / output."</span><span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Using Optimization Profile: </span><span class="si">{</span><span class="n">context_index</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="s1">'='</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">6</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="s1">'Name'</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_name_width</span><span class="si">}}</span><span class="s2">|I/O|</span><span class="si">{</span><span class="s1">'Shape'</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|"</span><span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="s1">'-'</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">6</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_io</span><span class="p">):</span>
|
|
<span class="n">name</span><span class="p">,</span> <span class="n">b_input</span><span class="p">,</span> <span class="n">shape</span> <span class="o">=</span> <span class="n">tensorInfo</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
|
<span class="n">info</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">name</span><span class="si">:</span><span class="s2"><</span><span class="si">{</span><span class="n">max_name_width</span><span class="si">}}</span><span class="s2">|</span><span class="si">{</span><span class="s1">'I'</span><span class="w"> </span><span class="k">if</span><span class="w"> </span><span class="n">b_input</span><span class="w"> </span><span class="k">else</span><span class="w"> </span><span class="s1">'O'</span><span class="si">:</span><span class="s2">^3s</span><span class="si">}</span><span class="s2">|</span><span class="si">{</span><span class="n">shape</span><span class="si">:</span><span class="s2">^</span><span class="si">{</span><span class="n">max_shape_width</span><span class="si">}}</span><span class="s2">|"</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="n">info</span><span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="s1">'='</span><span class="o">*</span><span class="p">(</span><span class="n">max_name_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">max_shape_width</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">6</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_set_shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">,</span>
|
|
<span class="n">shape_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]):</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">):</span>
|
|
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">shape_dict</span><span class="p">:</span>
|
|
<span class="c1"># shape and buffer can be set by calling _set_tensors API</span>
|
|
<span class="k">continue</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_mode</span><span class="p">(</span><span class="n">name</span><span class="p">)</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">TensorIOMode</span><span class="o">.</span><span class="n">INPUT</span><span class="p">:</span>
|
|
<span class="n">ok</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">set_input_shape</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">shape_dict</span><span class="p">[</span><span class="n">name</span><span class="p">])</span>
|
|
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s2">"setting input tensor </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> with shape </span><span class="si">{</span><span class="n">shape_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="si">}</span><span class="s2"> and type </span><span class="si">{</span><span class="n">dtype</span><span class="si">}</span><span class="s2">"</span>
|
|
<span class="p">)</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="n">ok</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s2">"Couldn't assign </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> with shape </span><span class="si">{</span><span class="n">shape_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="si">}</span><span class="s2">, "</span>
|
|
<span class="sa">f</span><span class="s2">"engine supports [min, opt, max] = </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_profile_shape</span><span class="p">(</span><span class="n">name</span><span class="p">,</span><span class="w"> </span><span class="n">context</span><span class="o">.</span><span class="n">active_optimization_profile</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
|
|
<span class="p">)</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_set_buffer</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">,</span>
|
|
<span class="n">buffer_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">):</span>
|
|
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">buffer_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
|
|
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
|
<span class="n">shape</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">get_tensor_shape</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
|
<span class="n">buffer_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">shape</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
|
<span class="k">assert</span> <span class="n">buffer_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">is_contiguous</span><span class="p">(</span>
|
|
<span class="p">),</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> is not contiguous()"</span>
|
|
<span class="n">context</span><span class="o">.</span><span class="n">set_tensor_address</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">buffer_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">())</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_set_tensors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">,</span>
|
|
<span class="n">tensors</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="s2">"RuntimeTensor"</span><span class="p">]):</span>
|
|
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_tensor_names</span><span class="p">:</span>
|
|
<span class="c1"># it's allowed to call set_tensors multi times with different tensors</span>
|
|
<span class="c1"># each time only set some of the engine tensors, so it is valid to skip the ones not in the current given tensors dict</span>
|
|
<span class="k">if</span> <span class="n">name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">:</span>
|
|
<span class="k">continue</span>
|
|
|
|
<span class="n">tensor</span> <span class="o">=</span> <span class="n">tensors</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
|
|
<span class="k">if</span> <span class="n">context</span><span class="o">.</span><span class="n">get_tensor_address</span><span class="p">(</span><span class="n">name</span><span class="p">)</span> <span class="o">!=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">data</span><span class="p">:</span>
|
|
<span class="n">context</span><span class="o">.</span><span class="n">set_tensor_address</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">tensor</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="nb">list</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">get_tensor_shape</span><span class="p">(</span><span class="n">name</span><span class="p">))</span> <span class="o">!=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
|
|
<span class="n">context</span><span class="o">.</span><span class="n">set_input_shape</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
|
|
|
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_tensor_names</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="n">name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">:</span>
|
|
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
|
<span class="n">shape</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">get_tensor_shape</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
|
<span class="n">tensors</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span>
|
|
<span class="n">name</span><span class="p">,</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">shape</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">))</span>
|
|
<span class="n">t</span> <span class="o">=</span> <span class="n">tensors</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
|
|
<span class="c1"># output's shape is inference by TRT, no need to set the shape here</span>
|
|
<span class="n">context</span><span class="o">.</span><span class="n">set_tensor_address</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">t</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_set_weight_streaming</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">gpu_weights_percent</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">streamable_weights_size</span><span class="p">:</span>
|
|
<span class="k">assert</span> <span class="n">gpu_weights_percent</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"Engine built without weight streaming. Cannot set gpu_weights_percent to a value other than 1."</span>
|
|
<span class="k">return</span>
|
|
|
|
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">context_0</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">context_1</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">ctx_context</span> <span class="o">=</span> <span class="kc">None</span>
|
|
|
|
<span class="nb">min</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="nb">max</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">streamable_weights_size</span>
|
|
<span class="n">budget</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">gpu_weights_percent</span> <span class="o">*</span> <span class="nb">max</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">weight_streaming_budget_v2</span> <span class="o">=</span> <span class="n">budget</span>
|
|
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">weight_streaming_budget_v2</span> <span class="o">==</span> <span class="n">budget</span><span class="p">,</span> <span class="s2">"Failed to set weight streaming budget!"</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s2">"Set gpu weights percent to </span><span class="si">{</span><span class="n">gpu_weights_percent</span><span class="si">}</span><span class="s2">, which is </span><span class="si">{</span><span class="n">budget</span><span class="si">}</span><span class="s2"> bytes. Valid range: </span><span class="si">{</span><span class="nb">min</span><span class="si">}</span><span class="s2"> bytes ~ </span><span class="si">{</span><span class="nb">max</span><span class="si">}</span><span class="s2"> bytes."</span>
|
|
<span class="p">)</span>
|
|
|
|
<span class="k">try</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">__prepare_execution_contexts</span><span class="p">()</span>
|
|
<span class="k">except</span><span class="p">:</span>
|
|
<span class="n">free_mem</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">mem_get_info</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="k">if</span> <span class="n">free_mem</span> <span class="o"><</span> <span class="n">budget</span><span class="p">:</span>
|
|
<span class="nb">print</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s2">"Failed to create context. Possibly out of memory: Memory budget is </span><span class="si">{</span><span class="n">budget</span><span class="si">}</span><span class="s2"> bytes but only </span><span class="si">{</span><span class="n">free_mem</span><span class="si">}</span><span class="s2"> bytes are available on the GPU."</span>
|
|
<span class="p">)</span>
|
|
<span class="k">raise</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_check_tensors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">tensors</span> <span class="o">=</span> <span class="p">[]</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">):</span>
|
|
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
|
|
<span class="n">ptr</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">get_tensor_address</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">ptr</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Engine I/O tensor </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> is unbound"</span><span class="p">)</span>
|
|
<span class="n">shp</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">get_tensor_shape</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
|
|
<span class="k">if</span> <span class="nb">any</span><span class="p">([</span><span class="n">s</span> <span class="o"><</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">shp</span><span class="p">]):</span> <span class="c1"># skip if shape is not available</span>
|
|
<span class="k">continue</span>
|
|
<span class="n">dt</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
|
|
<span class="n">tdt</span> <span class="o">=</span> <span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="n">dt</span><span class="p">)</span>
|
|
<span class="n">sz</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tdt</span><span class="p">)</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">shp</span><span class="p">)</span>
|
|
<span class="n">tensors</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">ptr</span><span class="p">,</span> <span class="n">ptr</span> <span class="o">+</span> <span class="n">sz</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">shp</span><span class="p">,</span> <span class="n">sz</span><span class="p">))</span>
|
|
<span class="n">tensors</span><span class="o">.</span><span class="n">sort</span><span class="p">()</span> <span class="c1"># sort by start address</span>
|
|
<span class="n">starts</span><span class="p">,</span> <span class="n">ends</span><span class="p">,</span> <span class="n">names</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">tensors</span><span class="p">)</span>
|
|
<span class="n">starts</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">starts</span><span class="p">)</span>
|
|
<span class="n">ends</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">ends</span><span class="p">)</span>
|
|
<span class="n">overalps</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nonzero</span><span class="p">((</span><span class="n">starts</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o"><</span> <span class="n">ends</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">int</span><span class="p">())</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
|
|
<span class="k">if</span> <span class="n">overalps</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="c1"># unsqueeze if there is a single value so it became scalar</span>
|
|
<span class="n">overalps</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">overalps</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">overalps</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="k">assert</span> <span class="n">overalps</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">1</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="n">overalps</span><span class="p">):</span>
|
|
<span class="n">left_name</span> <span class="o">=</span> <span class="n">names</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
|
<span class="n">right_name</span> <span class="o">=</span> <span class="n">names</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span>
|
|
<span class="k">if</span> <span class="s2">"key_value"</span> <span class="ow">in</span> <span class="n">left_name</span> <span class="ow">and</span> <span class="s2">"key_value"</span> <span class="ow">in</span> <span class="n">right_name</span><span class="p">:</span> <span class="c1"># kv</span>
|
|
<span class="n">left_names</span> <span class="o">=</span> <span class="n">left_name</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"_"</span><span class="p">)</span>
|
|
<span class="n">right_names</span> <span class="o">=</span> <span class="n">right_name</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"_"</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">left_names</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">right_names</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span> <span class="c1"># same kv layer</span>
|
|
<span class="k">assert</span> <span class="p">(</span><span class="n">left_names</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="s2">"past"</span> <span class="ow">and</span> <span class="n">right_names</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="s2">"present"</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span>
|
|
<span class="n">left_names</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="s2">"present"</span> <span class="ow">and</span> <span class="n">right_names</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="s2">"past"</span><span class="p">),</span> \
|
|
<span class="sa">f</span><span class="s2">"Overlap found between </span><span class="si">{</span><span class="n">tensors</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="s2"> and </span><span class="si">{</span><span class="n">tensors</span><span class="p">[</span><span class="n">i</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s2">"</span>
|
|
<span class="k">continue</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s2">"TENSOR BUFFER OVERLAP DETECTED: </span><span class="si">{</span><span class="n">tensors</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="s2"> and </span><span class="si">{</span><span class="n">tensors</span><span class="p">[</span><span class="n">i</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s2"> !!!"</span>
|
|
<span class="p">)</span>
|
|
<span class="k">return</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_insert_step_to_profiler</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">profiler</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s2">"Profiler is disable"</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">profiler</span><span class="o">.</span><span class="n">results</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="s2">"step"</span><span class="p">,</span> <span class="n">step</span><span class="p">))</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_is_profiling</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">profiler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_run</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">context</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IExecutionContext</span><span class="p">,</span>
|
|
<span class="n">stream</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="n">stream</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">cuda_stream</span>
|
|
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">stream</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">):</span>
|
|
<span class="n">stream</span> <span class="o">=</span> <span class="n">stream</span><span class="o">.</span><span class="n">cuda_stream</span>
|
|
<span class="n">ok</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">execute_async_v3</span><span class="p">(</span><span class="n">stream</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="n">ok</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__del__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">try</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">address</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaFree</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">address</span><span class="p">)</span>
|
|
<span class="k">except</span> <span class="ne">TypeError</span><span class="p">:</span>
|
|
<span class="k">pass</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">context_mem_size</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">device_memory_size_v2</span>
|
|
|
|
|
|
<div class="viewcode-block" id="ModelConfig">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.ModelConfig">[docs]</a>
|
|
<span class="nd">@dataclass</span>
|
|
<span class="k">class</span><span class="w"> </span><span class="nc">ModelConfig</span><span class="p">:</span>
|
|
<span class="n">max_batch_size</span><span class="p">:</span> <span class="nb">int</span>
|
|
<span class="n">max_beam_width</span><span class="p">:</span> <span class="nb">int</span>
|
|
<span class="n">vocab_size</span><span class="p">:</span> <span class="nb">int</span>
|
|
<span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span>
|
|
<span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span>
|
|
<span class="n">num_kv_heads</span><span class="p">:</span> <span class="nb">int</span>
|
|
<span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span>
|
|
<span class="n">gpt_attention_plugin</span><span class="p">:</span> <span class="nb">bool</span>
|
|
<span class="n">gemm_allreduce_plugin</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">remove_input_padding</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
|
|
<span class="n">model_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span>
|
|
<span class="n">kv_cache_type</span><span class="p">:</span> <span class="n">KVCacheType</span> <span class="o">=</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">CONTINUOUS</span>
|
|
<span class="n">cross_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
|
|
<span class="n">head_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">has_position_embedding</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
|
|
<span class="n">has_token_type_embedding</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
|
|
<span class="n">tokens_per_block</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span>
|
|
<span class="n">max_prompt_embedding_table_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="n">quant_mode</span><span class="p">:</span> <span class="n">QuantMode</span> <span class="o">=</span> <span class="n">QuantMode</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
|
<span class="n">gather_context_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
|
|
<span class="n">gather_generation_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
|
|
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span>
|
|
<span class="n">lora_plugin</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
|
|
<span class="n">lora_target_modules</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="nb">list</span><span class="p">)</span>
|
|
<span class="n">trtllm_modules_to_hf_modules</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">skip_cross_kv</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
|
|
<span class="n">num_medusa_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="n">max_medusa_tokens</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="n">paged_state</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
|
|
<span class="n">mamba_conv1d_plugin</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
|
|
<span class="n">conv_kernel</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="n">layer_types</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="nb">list</span><span class="p">)</span>
|
|
<span class="n">rnn_hidden_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="n">rnn_head_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="n">rnn_conv_dim_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="n">state_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="n">state_dtype</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">""</span>
|
|
<span class="n">gpu_weights_percent</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span>
|
|
<span class="c1"># ReDrafter</span>
|
|
<span class="n">redrafter_num_beams</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="n">redrafter_draft_len_per_beam</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="n">num_kv_heads_per_layer</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">num_kv_heads_per_cross_attn_layer</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">skip_cross_attn_blocks</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
|
|
<span class="c1"># language adapter</span>
|
|
<span class="n">language_adapter_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LanguageAdapterConfig</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span></div>
|
|
|
|
|
|
|
|
<div class="viewcode-block" id="SamplingConfig">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.SamplingConfig">[docs]</a>
|
|
<span class="nd">@dataclass</span>
|
|
<span class="k">class</span><span class="w"> </span><span class="nc">SamplingConfig</span><span class="p">:</span>
|
|
<span class="n">end_id</span><span class="p">:</span> <span class="nb">int</span>
|
|
<span class="n">pad_id</span><span class="p">:</span> <span class="nb">int</span>
|
|
|
|
<span class="n">max_new_tokens</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span>
|
|
<span class="n">num_beams</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">num_return_sequences</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
|
<span class="n">max_attention_window_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
|
<span class="n">sink_token_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
|
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
|
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
|
<span class="n">stop_words_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
|
<span class="n">bad_words_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
|
|
|
<span class="n">temperature</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
|
|
<span class="n">top_k</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">top_p</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
|
|
<span class="n">top_p_decay</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="c1"># float</span>
|
|
<span class="n">top_p_min</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="c1"># float</span>
|
|
<span class="n">top_p_reset_ids</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="c1"># int</span>
|
|
<span class="n">random_seed</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
|
|
|
<span class="n">length_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
|
|
<span class="n">early_stopping</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">repetition_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
|
|
<span class="n">min_length</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">presence_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
|
|
<span class="n">frequency_penalty</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
|
|
<span class="n">use_beam_hyps</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
|
|
|
<span class="c1"># None here means user didn't set it, and dynamicDecodeOp.cpp take optional value</span>
|
|
<span class="c1"># The real default value is set in dynamicDecodeOp.cpp when it's None</span>
|
|
<span class="n">beam_search_diversity_rate</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
|
|
<span class="n">output_cum_log_probs</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
|
<span class="n">output_log_probs</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
|
<span class="n">no_repeat_ngram_size</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
|
<span class="n">min_p</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
|
|
|
|
<div class="viewcode-block" id="SamplingConfig.update">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.SamplingConfig.update">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
|
<span class="n">unused_kwargs</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
|
|
<span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
|
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">):</span>
|
|
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">unused_kwargs</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span>
|
|
<span class="k">return</span> <span class="n">unused_kwargs</span></div>
|
|
</div>
|
|
|
|
|
|
|
|
<div class="viewcode-block" id="LogitsProcessor">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.LogitsProcessor">[docs]</a>
|
|
<span class="k">class</span><span class="w"> </span><span class="nc">LogitsProcessor</span><span class="p">:</span>
|
|
<span class="w"> </span><span class="sd">"""</span>
|
|
<span class="sd"> Base class for all logit processors that can be applied during generation.</span>
|
|
<span class="sd"> """</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="si">}</span><span class="s2"> is an abstract class. Only classes inheriting this class can be called."</span>
|
|
<span class="p">)</span></div>
|
|
|
|
|
|
|
|
<div class="viewcode-block" id="LogitsProcessorList">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.LogitsProcessorList">[docs]</a>
|
|
<span class="k">class</span><span class="w"> </span><span class="nc">LogitsProcessorList</span><span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="n">LogitsProcessor</span><span class="p">):</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
|
|
<span class="k">for</span> <span class="n">processor</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">:</span>
|
|
<span class="n">scores</span> <span class="o">=</span> <span class="n">processor</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">scores</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="n">scores</span></div>
|
|
|
|
|
|
|
|
<div class="viewcode-block" id="StoppingCriteria">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.StoppingCriteria">[docs]</a>
|
|
<span class="k">class</span><span class="w"> </span><span class="nc">StoppingCriteria</span><span class="p">:</span>
|
|
<span class="w"> </span><span class="sd">"""</span>
|
|
<span class="sd"> Base class for all stopping criteria that can be applied during generation.</span>
|
|
<span class="sd"> """</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">"StoppingCriteria needs to be subclassed"</span><span class="p">)</span></div>
|
|
|
|
|
|
|
|
<div class="viewcode-block" id="StoppingCriteriaList">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.StoppingCriteriaList">[docs]</a>
|
|
<span class="k">class</span><span class="w"> </span><span class="nc">StoppingCriteriaList</span><span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="n">StoppingCriteria</span><span class="p">):</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="nb">any</span><span class="p">(</span><span class="n">criteria</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">scores</span><span class="p">)</span> <span class="k">for</span> <span class="n">criteria</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">)</span></div>
|
|
|
|
|
|
|
|
<span class="k">class</span><span class="w"> </span><span class="nc">RuntimeTensor</span><span class="p">:</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="s2">""</span>
|
|
<span class="c1"># shape is the one sent to TRT, the actual torch tensor can be larger than the shape</span>
|
|
<span class="c1"># this is useful when allocating a big KV cache tensor at the beginning and incremental seq length dim of TRT engine's input tensor</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_torch_tensor</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="c1"># Used when pointer specified</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_data_ptr</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_dtype</span> <span class="o">=</span> <span class="kc">None</span>
|
|
|
|
<span class="nd">@staticmethod</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">from_pointer</span><span class="p">(</span><span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">pointer</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span>
|
|
<span class="n">str_dtype</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="s1">'RuntimeTensor'</span><span class="p">:</span>
|
|
<span class="n">t</span> <span class="o">=</span> <span class="n">RuntimeTensor</span><span class="p">()</span>
|
|
<span class="n">t</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="n">name</span>
|
|
<span class="n">t</span><span class="o">.</span><span class="n">_data_ptr</span> <span class="o">=</span> <span class="n">pointer</span>
|
|
<span class="n">t</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="n">shape</span>
|
|
<span class="n">t</span><span class="o">.</span><span class="n">_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_torch</span><span class="p">(</span><span class="n">str_dtype</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="n">t</span>
|
|
|
|
<span class="nd">@staticmethod</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">from_torch</span><span class="p">(</span>
|
|
<span class="n">name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
|
|
<span class="n">data</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">override_shape</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Iterable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="s1">'RuntimeTensor'</span><span class="p">:</span>
|
|
<span class="k">assert</span> <span class="p">(</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)),</span> <span class="sa">f</span><span class="s2">"data </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> is </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">data</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
|
|
<span class="n">t</span> <span class="o">=</span> <span class="n">RuntimeTensor</span><span class="p">()</span>
|
|
<span class="n">t</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="n">name</span>
|
|
<span class="c1"># need to hold the torch tensor for memory life time</span>
|
|
<span class="n">t</span><span class="o">.</span><span class="n">_torch_tensor</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
|
|
<span class="n">t</span><span class="o">.</span><span class="n">_dtype</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">_torch_tensor</span><span class="o">.</span><span class="n">dtype</span>
|
|
<span class="n">t</span><span class="o">.</span><span class="n">_data_ptr</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">_torch_tensor</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">()</span>
|
|
<span class="n">torch_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">data</span><span class="o">.</span><span class="n">size</span><span class="p">())</span>
|
|
<span class="k">if</span> <span class="n">override_shape</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">t</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="n">override_shape</span>
|
|
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">override_shape</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span>
|
|
<span class="n">override_shape</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span>
|
|
<span class="k">assert</span> <span class="nb">all</span><span class="p">([</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span> <span class="o">>=</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">override_shape</span>
|
|
<span class="p">]),</span> <span class="sa">f</span><span class="s2">"Expect all dimensions >=0, got </span><span class="si">{</span><span class="n">override_shape</span><span class="si">}</span><span class="s2">"</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">volume_func</span><span class="p">(</span><span class="n">dims</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="o">*</span> <span class="n">y</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
|
<span class="k">assert</span> <span class="n">volume_func</span><span class="p">(</span><span class="n">override_shape</span><span class="p">)</span> <span class="o"><=</span> <span class="n">volume_func</span><span class="p">(</span><span class="n">torch_shape</span><span class="p">),</span> \
|
|
<span class="sa">f</span><span class="s2">"Override the shape to be larger than the underlying torch Tensor, got </span><span class="si">{</span><span class="n">override_shape</span><span class="si">}</span><span class="s2">, torch tensor shape </span><span class="si">{</span><span class="n">torch_shape</span><span class="si">}</span><span class="s2">"</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">t</span><span class="o">.</span><span class="n">_shape</span> <span class="o">=</span> <span class="n">torch_shape</span>
|
|
<span class="k">return</span> <span class="n">t</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">to_torch</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_torch_tensor</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
|
|
<span class="s1">'RuntimeTensor cannot be converted to torch tensor as constructed from pointer'</span>
|
|
<span class="p">)</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_torch_tensor</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Iterable</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_shape</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">data</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_data_ptr</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">name</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_name</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">dtype</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dtype</span>
|
|
|
|
|
|
<div class="viewcode-block" id="GenerationSession">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession">[docs]</a>
|
|
<span class="k">class</span><span class="w"> </span><span class="nc">GenerationSession</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
|
|
|
|
<span class="n">_model_config</span><span class="p">:</span> <span class="n">ModelConfig</span>
|
|
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span>
|
|
<span class="n">runtime</span><span class="p">:</span> <span class="n">_Runtime</span>
|
|
<span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span>
|
|
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span>
|
|
<span class="n">buffer_allocated</span><span class="p">:</span> <span class="nb">bool</span>
|
|
<span class="n">debug_mode</span><span class="p">:</span> <span class="nb">bool</span>
|
|
<span class="n">quant_mode</span><span class="p">:</span> <span class="n">QuantMode</span>
|
|
<span class="n">cuda_graph_mode</span><span class="p">:</span> <span class="nb">bool</span>
|
|
<span class="n">dtype</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span>
|
|
<span class="n">debug_tensors_to_save</span><span class="p">:</span> <span class="kc">None</span>
|
|
<span class="n">num_draft_tokens</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="n">medusa_topks</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">medusa_paths</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">medusa_tree_ids</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">medusa_position_offsets</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">medusa_temperature</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">model_config</span><span class="p">:</span> <span class="n">ModelConfig</span><span class="p">,</span>
|
|
<span class="n">engine_buffer</span><span class="p">,</span>
|
|
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span>
|
|
<span class="n">debug_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">stream</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
|
|
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">model_config</span><span class="p">,</span> <span class="n">ModelConfig</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span> <span class="o">=</span> <span class="n">model_config</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span> <span class="o">=</span> <span class="n">mapping</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span> <span class="o">=</span> <span class="n">_Runtime</span><span class="p">(</span><span class="n">engine_buffer</span><span class="p">,</span> <span class="n">mapping</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">DISABLE_TORCH_DEVICE_SET</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="sa">f</span><span class="s1">'cuda:</span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_device</span><span class="p">()</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s1">'cuda:</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">runtime_rank</span><span class="w"> </span><span class="o">%</span><span class="w"> </span><span class="n">mapping</span><span class="o">.</span><span class="n">gpus_per_node</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="c1"># dynamic_decoder currently use torch's current stream, so must let TRT enqueue use same stream here</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">stream</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span> <span class="o">=</span> <span class="n">debug_mode</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span> <span class="o">=</span> <span class="n">debug_tensors_to_save</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span> <span class="o">=</span> <span class="n">cuda_graph_mode</span>
|
|
<span class="c1"># Optional inputs for dynamic decoder</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_decay</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_min</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_reset_ids</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="c1"># TODO: in tensorrt_llm/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp it's T, can be float or half?</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">embedding_bias_opt</span> <span class="o">=</span> <span class="kc">None</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span> <span class="o">=</span> <span class="kc">False</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span> <span class="o">=</span> <span class="n">pad_vocab_size</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">model_config</span><span class="o">.</span><span class="n">layer_types</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span> <span class="o">=</span> <span class="p">[</span><span class="s1">'attention'</span><span class="p">]</span> <span class="o">*</span> <span class="n">model_config</span><span class="o">.</span><span class="n">num_layers</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">layer_types</span> <span class="o">=</span> <span class="n">model_config</span><span class="o">.</span><span class="n">layer_types</span>
|
|
<span class="n">layer_types</span> <span class="o">=</span> <span class="n">layer_types</span> <span class="o">*</span> <span class="p">(</span><span class="n">model_config</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">//</span>
|
|
<span class="nb">len</span><span class="p">(</span><span class="n">layer_types</span><span class="p">))</span>
|
|
<span class="n">layer_types</span> <span class="o">=</span> <span class="n">layer_types</span> <span class="o">+</span> <span class="n">layer_types</span><span class="p">[</span><span class="mi">0</span><span class="p">:(</span><span class="n">model_config</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">%</span>
|
|
<span class="nb">len</span><span class="p">(</span><span class="n">layer_types</span><span class="p">))]</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span> <span class="o">=</span> <span class="n">layer_types</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span> <span class="o">=</span> \
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">]</span><span class="o">.</span><span class="n">count</span><span class="p">(</span><span class="s1">'attention'</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span> <span class="o">></span> <span class="mi">0</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span> <span class="o">=</span> <span class="s1">'recurrent'</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">]</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">attn_to_general_idx</span> <span class="o">=</span> <span class="p">{}</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">general_to_attn_idx</span> <span class="o">=</span> <span class="p">{}</span>
|
|
<span class="n">attn_layer_idx</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="s1">'attention'</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">attn_to_general_idx</span><span class="p">[</span><span class="n">attn_layer_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">i</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">general_to_attn_idx</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">attn_layer_idx</span>
|
|
<span class="n">attn_layer_idx</span> <span class="o">+=</span> <span class="mi">1</span>
|
|
|
|
<span class="c1"># Cyclic KV cache buffer names.</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn_to_general_idx</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_buffer_names</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span>
|
|
<span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">layer_idx</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn_to_general_idx</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
|
|
<span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="sa">f</span><span class="s1">'1_present_key_value_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">attn_to_general_idx</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_buffer_names</span> <span class="o">=</span> <span class="p">[]</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
|
<span class="s2">"The paged KV cache in Python runtime is experimental. For performance and correctness, please, use C++ runtime."</span>
|
|
<span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">classes</span><span class="o">.</span><span class="n">trtllm</span><span class="o">.</span><span class="n">NcclCommunicatorOp</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">world_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">'logits'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">]:</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
|
<span class="s2">"Logits dtype not supported by decoder. Falling back to float32. You may want to change the logits dtype to float16 in your model definition."</span>
|
|
<span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">dynamic_decoder</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">classes</span><span class="o">.</span><span class="n">trtllm</span><span class="o">.</span><span class="n">DynamicDecodeOp</span><span class="p">(</span>
|
|
<span class="n">model_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">,</span> <span class="n">model_config</span><span class="o">.</span><span class="n">max_beam_width</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
|
|
|
|
<span class="n">expected_tensor_names</span> <span class="o">=</span> <span class="p">[]</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">ipc_buffers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span> <span class="o">=</span> <span class="n">CustomAllReduceHelper</span><span class="o">.</span><span class="n">allocate_workspace</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="p">,</span>
|
|
<span class="n">CustomAllReduceHelper</span><span class="o">.</span><span class="n">max_workspace_size_auto</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">))</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">gather_tree</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">tensorrt_llm</span><span class="o">.</span><span class="n">gather_tree</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'input_ids'</span><span class="p">]</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'hidden_states_input'</span><span class="p">]</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="n">model_config</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'last_token_ids'</span><span class="p">]</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'hidden_states_output'</span><span class="p">]</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">has_position_embedding</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">(</span>
|
|
<span class="p">):</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'position_ids'</span><span class="p">]</span>
|
|
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">has_token_type_embedding</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">(</span>
|
|
<span class="p">):</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'token_type_ids'</span><span class="p">]</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_kv_cache</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'cache_indirection'</span><span class="p">]</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">'kv_cache_block_offsets'</span><span class="p">]</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">'host_kv_cache_block_offsets'</span><span class="p">]</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">'host_kv_cache_pool_pointers'</span><span class="p">]</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">'host_kv_cache_pool_mapping'</span><span class="p">]</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">'cross_kv_cache_block_offsets'</span><span class="p">]</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">'host_cross_kv_cache_block_offsets'</span><span class="p">]</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">'host_cross_kv_cache_pool_pointers'</span><span class="p">]</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">'host_cross_kv_cache_pool_mapping'</span><span class="p">]</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">'cross_attention_mask'</span><span class="p">]</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">'cross_attention_packed_mask'</span><span class="p">]</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="c1"># Refer to gpt_attention() inside functional.py</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_kv_cache</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="s1">'attention'</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
|
<span class="sa">f</span><span class="s1">'past_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span>
|
|
<span class="p">]</span>
|
|
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="s1">'attention'</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
|
<span class="sa">f</span><span class="s1">'cross_present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'cross_past_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span>
|
|
<span class="p">]</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
|
<span class="s1">'cross_attention_mask'</span><span class="p">,</span>
|
|
<span class="p">]</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">'cross_attention_packed_mask'</span><span class="p">]</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
|
<span class="s1">'cross_attention_mask'</span><span class="p">,</span>
|
|
<span class="p">]</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="s1">'recurrent'</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
|
<span class="sa">f</span><span class="s1">'conv_state_ptr_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'rnn_state_ptr_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span>
|
|
<span class="p">]</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'slot_mapping'</span><span class="p">]</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="s1">'recurrent'</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
|
<span class="sa">f</span><span class="s1">'past_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'past_rnn_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'present_rnn_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span>
|
|
<span class="p">]</span>
|
|
|
|
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_kv_cache</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
|
<span class="s1">'sequence_length'</span><span class="p">,</span> <span class="s1">'host_past_key_value_lengths'</span>
|
|
<span class="p">]</span>
|
|
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
|
<span class="s1">'context_lengths'</span><span class="p">,</span> <span class="s1">'host_request_types'</span><span class="p">,</span>
|
|
<span class="s1">'host_sink_token_length'</span><span class="p">,</span> <span class="s1">'host_runtime_perf_knobs'</span><span class="p">,</span>
|
|
<span class="s1">'host_context_progress'</span>
|
|
<span class="p">]</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s1">'host_max_attention_window_sizes'</span><span class="p">]</span>
|
|
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s1">'host_context_lengths'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'host_request_types'</span><span class="p">]</span>
|
|
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">mamba_conv1d_plugin</span> <span class="ow">and</span> <span class="n">model_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s1">'host_context_lengths'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'attention_mask'</span><span class="p">]</span>
|
|
|
|
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">max_prompt_embedding_table_size</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
|
<span class="s1">'prompt_embedding_table'</span><span class="p">,</span> <span class="s1">'tasks'</span><span class="p">,</span> <span class="s1">'prompt_vocab_size'</span>
|
|
<span class="p">]</span>
|
|
|
|
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
|
<span class="s1">'encoder_output'</span><span class="p">,</span>
|
|
<span class="s1">'encoder_input_lengths'</span><span class="p">,</span>
|
|
<span class="s1">'encoder_max_input_length'</span><span class="p">,</span>
|
|
<span class="s1">'cross_kv_cache_gen'</span><span class="p">,</span>
|
|
<span class="p">]</span>
|
|
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">skip_cross_attn_blocks</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'skip_cross_attn_blocks'</span><span class="p">]</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">skip_cross_kv</span> <span class="o">=</span> <span class="n">model_config</span><span class="o">.</span><span class="n">skip_cross_kv</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip_cross_kv</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'cross_kv_reuse'</span><span class="p">]</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'all_reduce_workspace'</span><span class="p">]</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="o">=</span> <span class="n">model_config</span><span class="o">.</span><span class="n">lora_target_modules</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span> <span class="o">=</span> <span class="n">LoraManager</span><span class="o">.</span><span class="n">get_missing_qkv_modules</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">lora_plugin</span><span class="p">:</span>
|
|
<span class="k">for</span> <span class="n">lora_module</span> <span class="ow">in</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="o">+</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span><span class="p">):</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
|
<span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_ranks_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_weights_pointers_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span>
|
|
<span class="p">]</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'host_encoder_input_lengths'</span><span class="p">]</span>
|
|
|
|
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">num_medusa_heads</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span>
|
|
<span class="s1">'spec_decoding_generation_lengths'</span><span class="p">,</span>
|
|
<span class="s1">'spec_decoding_position_offsets'</span><span class="p">,</span> <span class="s1">'spec_decoding_packed_mask'</span><span class="p">,</span>
|
|
<span class="s1">'spec_decoding_use'</span><span class="p">,</span> <span class="s1">'medusa_logits'</span>
|
|
<span class="p">]</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="n">get_redrafter_tensor_names</span><span class="p">()</span>
|
|
|
|
<span class="c1"># language adapter</span>
|
|
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">language_adapter_config</span><span class="p">:</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="s1">'language_adapter_routings'</span><span class="p">]</span>
|
|
|
|
<span class="n">found_tensor_names</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">)</span>
|
|
<span class="p">]</span>
|
|
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">found_tensor_names</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"allreduce_ub_"</span><span class="p">)</span> <span class="ow">or</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span>
|
|
<span class="s2">"gemm_allreduce"</span><span class="p">):</span>
|
|
<span class="n">expected_tensor_names</span> <span class="o">+=</span> <span class="p">[</span><span class="n">name</span><span class="p">]</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span> <span class="ow">and</span> <span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">set</span><span class="p">(</span>
|
|
<span class="n">found_tensor_names</span><span class="p">):</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s2">"The following expected tensors are not found: </span><span class="si">{</span><span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">)</span><span class="o">.</span><span class="n">difference</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">))</span><span class="si">}</span><span class="s2">"</span>
|
|
<span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s2">"Those tensors in engine are not expected: </span><span class="si">{</span><span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">)</span><span class="o">.</span><span class="n">difference</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">))</span><span class="si">}</span><span class="s2">"</span>
|
|
<span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Expected tensor names: </span><span class="si">{</span><span class="n">expected_tensor_names</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Found tensor names: </span><span class="si">{</span><span class="n">found_tensor_names</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
|
|
<span class="s2">"Tensor names in engine are not the same as expected, to use this GenerationSession, "</span>
|
|
<span class="s2">"you need to use PretrainedModel.prepare_inputs to create TRT Network inputs."</span>
|
|
<span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span>
|
|
<span class="nb">set</span><span class="p">(</span><span class="n">found_tensor_names</span><span class="p">)</span> <span class="o">-</span> <span class="nb">set</span><span class="p">(</span><span class="n">expected_tensor_names</span><span class="p">))</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Debug tensors found: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Debug tensors to save: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__del__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">try</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gemm_allreduce_plugin</span><span class="p">:</span>
|
|
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
|
<span class="n">ipc_nvls_free</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span><span class="p">)</span>
|
|
<span class="k">except</span> <span class="ne">TypeError</span><span class="p">:</span>
|
|
<span class="k">pass</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">context_mem_size</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_mem_size</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">vocab_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">vocab_size</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">num_layers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> \
|
|
<span class="sa">f</span><span class="s2">"num_layers </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_layers</span><span class="si">}</span><span class="s2"> must be a multiple of pipeline parallelism size </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span><span class="si">}</span><span class="s2">"</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">first_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_rank</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">last_layer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">num_heads</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_heads</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">hidden_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="c1"># For linear layer in attention block</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">hidden_size</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">use_gpt_attention_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">use_mamba_conv1d_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">mamba_conv1d_plugin</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">paged_kv_cache</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">==</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">PAGED</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">kv_cache_type</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">kv_cache_type</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">use_kv_cache</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">!=</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">DISABLED</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">tokens_per_block</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">tokens_per_block</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">remove_input_padding</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">remove_input_padding</span>
|
|
|
|
<div class="viewcode-block" id="GenerationSession.get_num_heads_kv">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.get_num_heads_kv">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">get_num_heads_kv</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">layer_idx</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="n">layer_idx</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_layer</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">layer_types</span><span class="p">:</span>
|
|
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span>
|
|
<span class="n">layer_idx</span><span class="p">]</span> <span class="o">==</span> <span class="s2">"attention"</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"Layer </span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s2"> is not an attention layer"</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_layer</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_layer</span><span class="p">[</span><span class="n">layer_idx</span><span class="p">]</span>
|
|
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads</span></div>
|
|
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">head_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">head_size</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">head_size</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">max_prompt_embedding_table_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">max_prompt_embedding_table_size</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">quant_mode</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">quant_mode</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">gather_context_logits</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">gather_context_logits</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">gather_generation_logits</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">gather_generation_logits</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="n">str_dtype_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">profiler</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">profiler</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">engine_inspector</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine_inspector</span>
|
|
|
|
<div class="viewcode-block" id="GenerationSession.cuda_stream_guard">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.cuda_stream_guard">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">cuda_stream_guard</span><span class="p">(</span><span class="n">func</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""Sync external stream and set current stream to the one bound to the session. Reset on exit.</span>
|
|
<span class="sd"> """</span>
|
|
|
|
<span class="nd">@wraps</span><span class="p">(</span><span class="n">func</span><span class="p">)</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">wrapper</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
|
<span class="n">external_stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span>
|
|
<span class="k">if</span> <span class="n">external_stream</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">:</span>
|
|
<span class="n">external_stream</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">)</span>
|
|
<span class="n">ret</span> <span class="o">=</span> <span class="n">func</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">external_stream</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_stream</span><span class="p">(</span><span class="n">external_stream</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="n">ret</span>
|
|
|
|
<span class="k">return</span> <span class="n">wrapper</span></div>
|
|
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">cross_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">cross_attention</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">has_position_embedding</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">has_position_embedding</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">has_token_type_embedding</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">has_token_type_embedding</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">use_lora_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">lora_plugin</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">use_gemm_allreduce_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="nb">bool</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">gemm_allreduce_plugin</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">is_medusa_mode</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span> <span class="o">></span> <span class="mi">0</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">is_redrafter_mode</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">redrafter_num_beams</span> <span class="o">></span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">redrafter_draft_len_per_beam</span> <span class="o">></span> <span class="mi">0</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">max_draft_tokens</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">redrafter_num_beams</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">redrafter_draft_len_per_beam</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">max_medusa_tokens</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">num_medusa_heads</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_medusa_heads</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">paged_state</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">paged_state</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">conv_kernel</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">conv_kernel</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">rnn_hidden_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">rnn_hidden_size</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">rnn_head_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">rnn_head_size</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">rnn_conv_dim_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">rnn_conv_dim_size</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">state_size</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">state_size</span>
|
|
|
|
<span class="nd">@property</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">state_dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">state_dtype</span> <span class="o">==</span> <span class="s2">""</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="n">str_dtype_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="n">str_dtype_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">state_dtype</span><span class="p">)</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_capture_cuda_graph_and_instantiate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">,</span> <span class="n">stream</span><span class="p">,</span> <span class="n">step</span><span class="p">):</span>
|
|
<span class="n">instance_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="c1"># Create two cuda graph once.If cuda graph has already existed, skip it.</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="k">return</span>
|
|
<span class="c1"># capture cuda graph</span>
|
|
<span class="n">CUASSERT</span><span class="p">(</span>
|
|
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaStreamBeginCapture</span><span class="p">(</span>
|
|
<span class="n">stream</span><span class="p">,</span>
|
|
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaStreamCaptureMode</span><span class="o">.</span><span class="n">cudaStreamCaptureModeGlobal</span><span class="p">))</span>
|
|
<span class="n">context</span><span class="o">.</span><span class="n">execute_async_v3</span><span class="p">(</span><span class="n">stream</span><span class="p">)</span>
|
|
<span class="n">next_graph</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span><span class="n">cudart</span><span class="o">.</span><span class="n">cudaStreamEndCapture</span><span class="p">(</span><span class="n">stream</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span>
|
|
<span class="n">instance_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">_update_cuda_graph_instance</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">],</span> <span class="n">next_graph</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">CUASSERT</span><span class="p">(</span>
|
|
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphInstantiate</span><span class="p">(</span><span class="n">next_graph</span><span class="p">,</span> <span class="mi">0</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
|
|
|
|
<span class="c1"># Pre-upload cuda graph to stream</span>
|
|
<span class="n">CUASSERT</span><span class="p">(</span>
|
|
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphUpload</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">],</span> <span class="n">stream</span><span class="p">))</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">__setup_decoder</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
|
|
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">'''Allocate buffers and setup the post-processing decoder kernel</span>
|
|
<span class="sd"> '''</span>
|
|
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="n">scfg</span> <span class="o">=</span> <span class="n">sampling_config</span> <span class="c1"># just to make a shorter name, no other meaning</span>
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.top_k.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int32"</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
|
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.top_k.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.top_p.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32"</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
|
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.top_p.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">top_p</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">top_p</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">top_p</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.temperature.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32"</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
|
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.temperature.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">temperature</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">temperature</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">temperature</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.repetition_penalty.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32"</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
|
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.repetition_penalty.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">repetition_penalty</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span>
|
|
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span> <span class="o">==</span> <span class="mf">1.0</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">repetition_penalty</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">repetition_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.length_penalty.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32"</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
|
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.length_penalty.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">length_penalty</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">length_penalty</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.early_stopping.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int32"</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
|
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.early_stopping.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">host_early_stopping</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">host_early_stopping</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">early_stopping</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.presence_penalty.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32"</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
|
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.presence_penalty.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span>
|
|
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">==</span> <span class="mf">0.0</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">presence_penalty</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.frequency_penalty.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32"</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
|
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.frequency_penalty.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">frequency_penalty</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span>
|
|
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span> <span class="o">==</span> <span class="mf">0.0</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">frequency_penalty</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">frequency_penalty</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.min_length.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int32"</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
|
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.min_length.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">min_length</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">min_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">min_length</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.beam_search_diversity_rate.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32"</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
|
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.beam_search_diversity_rate.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span>
|
|
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span> <span class="o">=</span> <span class="kc">None</span>
|
|
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.random_seed.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int64"</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
|
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.random_seed.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span>
|
|
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">random_seed</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span> <span class="o">=</span> <span class="kc">None</span>
|
|
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.no_repeat_ngram_size.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.int32"</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
|
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.no_repeat_ngram_size.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span>
|
|
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span> <span class="o">=</span> <span class="kc">None</span>
|
|
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_p</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_p</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.min_p.dtype (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_p</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) must be torch.float32"</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_p</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
|
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">batch_size</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"scfg.min_p.shape[0] (</span><span class="si">{</span><span class="n">scfg</span><span class="o">.</span><span class="n">min_p</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">) must equal to batch_size (</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">)"</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">min_p</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_p</span>
|
|
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">min_p</span> <span class="o">==</span> <span class="mf">1.0</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">min_p</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">min_p</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">min_p</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">dynamic_decoder</span><span class="o">.</span><span class="n">setup</span><span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">top_p</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">temperature</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">repetition_penalty</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">presence_penalty</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">frequency_penalty</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">min_length</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">host_length_penalty</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">host_early_stopping</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_search_diversity_rate</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">random_seed</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_decay</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_min</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">top_p_reset_ids</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">no_repeat_ngram_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">min_p</span><span class="p">,</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">output_log_probs</span><span class="p">,</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">></span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_cum_log_probs</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"end_id cannot be none"</span>
|
|
<span class="k">assert</span> <span class="n">scfg</span><span class="o">.</span><span class="n">pad_id</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">'pad_id cannot be none'</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">max</span><span class="p">()</span>
|
|
|
|
<span class="c1"># setup output ids buffer</span>
|
|
<span class="k">if</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
|
<span class="c1"># input_ids only have one dimension, which means remove_padding is enabled</span>
|
|
<span class="n">split_ids_list</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">input_ids</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
|
<span class="n">host_context_lengths</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span>
|
|
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
|
|
<span class="n">padded_input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">to_padded_tensor</span><span class="p">(</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">nested_tensor</span><span class="p">(</span><span class="n">split_ids_list</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">),</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">pad_id</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">padded_input_ids</span> <span class="o">=</span> <span class="n">input_ids</span>
|
|
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
|
<span class="n">tiled_input_ids</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">padded_input_ids</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">)</span>
|
|
<span class="n">tiled_input_ids</span> <span class="o">=</span> <span class="n">tiled_input_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span>
|
|
<span class="n">max_context_length</span><span class="p">)</span>
|
|
<span class="n">tiled_input_ids</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># TODO: delete?</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="n">tiled_input_ids</span><span class="p">,</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">-</span> <span class="n">max_context_length</span><span class="p">),</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)),</span>
|
|
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="n">padded_input_ids</span><span class="p">,</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">-</span> <span class="n">max_context_length</span><span class="p">),</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="n">padded_input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)),</span>
|
|
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
|
|
|
<span class="c1"># Note: we still allocate max_seq_length size of parent ids (not max_attention_window_size).</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span>
|
|
<span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">redrafter_draft_len_per_beam</span> <span class="o">+</span> <span class="mi">1</span>
|
|
<span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s2">"redrafter_inverted_temperature"</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">reciprocal</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">temperature</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
|
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_temperature</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_logits</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">'logits'</span><span class="p">),</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="k">elif</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">></span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_cum_log_probs</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">),</span>
|
|
<span class="o">-</span><span class="mf">1e20</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.0</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span> <span class="o">=</span> <span class="kc">None</span>
|
|
|
|
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_log_probs</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs_tiled</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">,</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs_tiled</span> <span class="o">=</span> <span class="kc">None</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">finished</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">uint8</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_cba</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">(</span>
|
|
<span class="n">size</span><span class="o">=</span><span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">],</span>
|
|
<span class="n">fill_value</span><span class="o">=</span><span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_seq_len_cba</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs_cba</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores_cba</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs_cba</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_cba</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_seq_len_cba</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs_cba</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores_cba</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs_cba</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span> <span class="o">=</span> <span class="kc">None</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">cross_kv_reuse</span> <span class="o">=</span> <span class="kc">None</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_tensor_dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
|
|
<span class="c1"># return torch dtype given tensor name for convenience</span>
|
|
<span class="n">dtype</span> <span class="o">=</span> <span class="n">trt_dtype_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
|
|
<span class="k">return</span> <span class="n">dtype</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_init_medusa</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">medusa_choices</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]):</span>
|
|
<span class="kn">from</span><span class="w"> </span><span class="nn">tensorrt_llm.runtime.medusa_utils</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span><span class="n">_medusa_setup</span><span class="p">,</span>
|
|
<span class="n">expand_choices_if_needed</span><span class="p">)</span>
|
|
<span class="n">medusa_choices</span> <span class="o">=</span> <span class="n">expand_choices_if_needed</span><span class="p">(</span><span class="n">medusa_choices</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">medusa_choices</span><span class="p">)</span>
|
|
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">></span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o"><=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_draft_tokens</span>
|
|
<span class="n">medusa_info</span> <span class="o">=</span> <span class="n">_medusa_setup</span><span class="p">(</span><span class="n">medusa_choices</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_topks</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_topks</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_mask</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_mask</span><span class="p">[</span><span class="mi">1</span><span class="p">:,</span> <span class="mi">1</span><span class="p">:]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">bool</span>
|
|
<span class="p">)</span> <span class="c1"># convert to bool, original mask includes true token as well</span>
|
|
|
|
<span class="c1"># Expand medusa position offsets to number of batch size in order to be compatible with the new Medusa.</span>
|
|
<span class="n">target_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_packed_mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
|
<span class="n">target_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span>
|
|
<span class="c1"># Note: spec_decoding_packed_mask has no paddings in the first dimension.</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">spec_decoding_packed_mask</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_packed_mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span>
|
|
<span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">target_shape</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">target_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">spec_decoding_use</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_spec_decoding_use</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_paths</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_tree_ids</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_tree_ids</span>
|
|
|
|
<span class="c1"># Expand medusa position offsets to number of batch size in order to be compatible with the new Medusa.</span>
|
|
<span class="n">target_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span>
|
|
<span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_position_offsets</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
|
<span class="n">target_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span>
|
|
<span class="c1"># Note: medusa_position_offsets still keeps the paddings in order to get max_gen_input_length from the shape info.</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">spec_decoding_position_offsets</span> <span class="o">=</span> <span class="n">medusa_info</span><span class="o">.</span><span class="n">medusa_position_offsets</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span>
|
|
<span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">target_shape</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
|
|
<span class="c1"># Fixed sequence lengths currently.</span>
|
|
<span class="c1"># Support variable sequence lengths later.</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">spec_decoding_generation_lengths</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">))</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="n">medusa_fp_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">medusa_mask</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
|
<span class="n">medusa_fp_mask</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">logical_not</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">medusa_mask</span><span class="p">)]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="s1">'-inf'</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_mask</span> <span class="o">=</span> <span class="n">medusa_fp_mask</span>
|
|
<span class="k">return</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_get_num_paged_blocks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_attention_window_size</span><span class="p">,</span>
|
|
<span class="n">sink_token_length</span><span class="p">):</span>
|
|
<span class="n">bubble_len</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="k">if</span> <span class="n">sink_token_length</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="n">bubble_len</span> <span class="o">+=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span> <span class="o">-</span>
|
|
<span class="n">sink_token_length</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">)</span>
|
|
<span class="n">max_blocks_per_seq</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="n">max_attention_window_size</span> <span class="o">+</span> <span class="n">bubble_len</span><span class="p">)</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">)</span>
|
|
<span class="n">num_blocks</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_width</span> <span class="o">*</span> <span class="n">max_blocks_per_seq</span>
|
|
|
|
<span class="k">return</span> <span class="n">num_blocks</span><span class="p">,</span> <span class="n">max_blocks_per_seq</span>
|
|
|
|
<div class="viewcode-block" id="GenerationSession.setup">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.setup">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">max_new_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
|
<span class="n">max_attention_window_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">sink_token_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">encoder_max_input_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">lora_manager</span><span class="p">:</span> <span class="n">LoraManager</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">lora_uids</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">medusa_choices</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">multi_block_mode</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
|
|
<span class="n">enable_context_fmha_fp32_acc</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
|
|
<span class="c1"># Store these params related to buffer size to check against</span>
|
|
<span class="c1"># the input shape with the params given in decode()</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_context_length</span> <span class="o">=</span> <span class="n">max_context_length</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">=</span> <span class="n">max_new_tokens</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">=</span> <span class="n">max_context_length</span> <span class="o">+</span> <span class="n">max_new_tokens</span>
|
|
<span class="k">if</span> <span class="n">medusa_choices</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_draft_tokens</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_width</span> <span class="o">=</span> <span class="n">beam_width</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span> <span class="o">=</span> <span class="n">encoder_max_input_length</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">multi_block_mode</span> <span class="o">=</span> <span class="n">multi_block_mode</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">enable_context_fmha_fp32_acc</span> <span class="o">=</span> <span class="n">enable_context_fmha_fp32_acc</span>
|
|
<span class="k">if</span> <span class="n">max_attention_window_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
|
|
<span class="s2">"The max_attention_window_size is not set, we will use max_seq_length by default."</span>
|
|
<span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">,</span> <span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span>
|
|
|
|
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="n">max_attention_window_size</span> <span class="o">></span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">:</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
|
<span class="s2">"The value of max_attention_window_size should ideally not exceed max_seq_length. "</span>
|
|
<span class="s2">"Therefore, it has been adjusted to match the value of max_seq_length."</span>
|
|
<span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">,</span> <span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span>
|
|
|
|
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">list</span><span class="p">)):</span>
|
|
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
|
|
<span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
|
|
<span class="n">max_attention_window_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">max_attention_window_size</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
|
|
<span class="n">attn_win_size_len</span> <span class="o">=</span> <span class="n">max_attention_window_size</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="n">num_total_attn_layers</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="o">.</span><span class="n">count</span><span class="p">(</span><span class="s1">'attention'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">attn_win_size_len</span> <span class="o"><</span> <span class="n">num_total_attn_layers</span><span class="p">:</span>
|
|
<span class="n">repeat_num</span> <span class="o">=</span> <span class="n">num_total_attn_layers</span> <span class="o">//</span> <span class="n">attn_win_size_len</span>
|
|
<span class="n">remain_num</span> <span class="o">=</span> <span class="n">num_total_attn_layers</span> <span class="o">%</span> <span class="n">attn_win_size_len</span>
|
|
<span class="n">warning_info</span> <span class="o">=</span> <span class="s2">"The size of max_attention_window_size tensor/list is less than num_attn_layers, "</span> \
|
|
<span class="o">+</span> <span class="s2">"and it will be repeated to num_attn_layers. So the actual max_attention_window_size "</span> \
|
|
<span class="o">+</span> <span class="sa">f</span><span class="s2">"is </span><span class="si">{</span><span class="n">max_attention_window_size</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span><span class="si">}</span><span class="s2"> * </span><span class="si">{</span><span class="n">repeat_num</span><span class="si">}</span><span class="s2">"</span>
|
|
<span class="n">warning_info</span> <span class="o">+=</span> <span class="sa">f</span><span class="s2">" + </span><span class="si">{</span><span class="n">max_attention_window_size</span><span class="o">.</span><span class="n">tolist</span><span class="p">()[</span><span class="mi">0</span><span class="p">:</span><span class="n">remain_num</span><span class="p">]</span><span class="si">}</span><span class="s2">. "</span> <span class="k">if</span> <span class="n">remain_num</span> <span class="o">></span> <span class="mi">0</span> <span class="k">else</span> <span class="s2">". "</span>
|
|
<span class="n">warning_info</span> <span class="o">+=</span> <span class="s2">"Note that num_attn_layers is the number of total attention layers."</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="n">warning_info</span><span class="p">)</span>
|
|
<span class="k">elif</span> <span class="n">attn_win_size_len</span> <span class="o">></span> <span class="n">num_total_attn_layers</span><span class="p">:</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
|
|
<span class="s2">"The size of max_attention_window_size tensor/list is larger than num_attn_layers! "</span>
|
|
<span class="s2">"Note that num_attn_layers is the number of total attention layers."</span>
|
|
<span class="p">)</span>
|
|
<span class="k">assert</span> <span class="kc">False</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">></span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">:</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
|
<span class="s2">"The value of max_attention_window_size should ideally not exceed max_seq_length. "</span>
|
|
<span class="s2">"Therefore, it has been adjusted to match the value of max_seq_length."</span>
|
|
<span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span>
|
|
<span class="n">max_attention_window_size</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">minimum</span><span class="p">(</span>
|
|
<span class="n">max_attention_window_size</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">IntTensor</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">]</span> <span class="o">*</span> <span class="n">attn_win_size_len</span><span class="p">))</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span><span class="p">[</span>
|
|
<span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">max_attention_window_size</span><span class="p">[</span>
|
|
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">]</span><span class="o">.</span><span class="n">count</span><span class="p">(</span><span class="s1">'attention'</span><span class="p">)</span>
|
|
<span class="o">+</span> <span class="n">i</span><span class="p">)</span> <span class="o">%</span> <span class="n">attn_win_size_len</span><span class="p">]</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s2">"invalid max_attention_window_size!"</span>
|
|
|
|
<span class="k">if</span> <span class="n">sink_token_length</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">host_sink_token_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">sink_token_length</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o">=</span> <span class="n">sink_token_length</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">host_sink_token_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s2">"invalid sink_token_length!"</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span> <span class="o">=</span> <span class="n">lora_manager</span>
|
|
<span class="k">if</span> <span class="n">medusa_choices</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_init_medusa</span><span class="p">(</span><span class="n">medusa_choices</span><span class="p">)</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="o">=</span> <span class="p">{}</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="n">init_allocate_redrafter_tensors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="k">else</span>
|
|
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">'logits'</span><span class="p">),</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="k">else</span>
|
|
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">'logits'</span><span class="p">),</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="n">medusa_logits_shape</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">medusa_logits_shape</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span> <span class="o">*</span>
|
|
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'medusa_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
|
<span class="n">medusa_logits_shape</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="k">else</span>
|
|
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">'medusa_logits'</span><span class="p">),</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="k">else</span>
|
|
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">'logits'</span><span class="p">),</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
|
|
<span class="c1"># use shape info to pass max length info in remove padding mode</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'encoder_max_input_length'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="n">encoder_max_input_length</span><span class="p">,</span> <span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="s1">'encoder_max_input_length'</span><span class="p">),</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_kv_cache_quant</span><span class="p">():</span>
|
|
<span class="c1"># Since torch does not support fp8 now, using int8 here.</span>
|
|
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">int8</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="n">first_atten_layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">]</span><span class="o">.</span><span class="n">index</span><span class="p">(</span>
|
|
<span class="s1">'attention'</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span>
|
|
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">first_atten_layer</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="kc">None</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_kv_cache</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="n">num_blocks</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_num_paged_blocks</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_memory_pool_allocator</span> <span class="o">=</span> <span class="n">MemoryPoolsAllocator</span><span class="p">(</span>
|
|
<span class="n">num_blocks</span><span class="o">=</span><span class="n">num_blocks</span><span class="p">,</span>
|
|
<span class="n">tokens_per_block</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
|
|
<span class="n">head_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_layer</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">num_kv_heads_per_layer</span> <span class="o">=</span> <span class="n">MemoryPoolsAllocator</span><span class="o">.</span><span class="n">prepare_num_kv_heads_per_layer</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">num_kv_heads_per_layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_layer</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_memory_pool_allocator</span><span class="o">.</span><span class="n">allocate</span><span class="p">(</span><span class="n">kv_cache_type</span><span class="p">,</span>
|
|
<span class="n">num_kv_heads_per_layer</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span> <span class="c1"># As for now we enable cross paged kv and self paged kv to share the same tokens_per_block</span>
|
|
<span class="n">cross_num_blocks</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_num_paged_blocks</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span><span class="p">,</span> <span class="n">sink_token_length</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
|
|
|
<span class="n">num_kv_heads_per_layer</span> <span class="o">=</span> <span class="n">MemoryPoolsAllocator</span><span class="o">.</span><span class="n">prepare_num_kv_heads_per_layer</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">)</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_cross_memory_pool_allocator</span> <span class="o">=</span> <span class="n">MemoryPoolsAllocator</span><span class="p">(</span>
|
|
<span class="n">num_blocks</span><span class="o">=</span><span class="n">cross_num_blocks</span><span class="p">,</span>
|
|
<span class="n">tokens_per_block</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
|
|
<span class="n">head_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_cross_attn_layer</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">num_kv_heads_per_cross_attn_layer</span> <span class="o">=</span> <span class="n">MemoryPoolsAllocator</span><span class="o">.</span><span class="n">prepare_num_kv_heads_per_layer</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">num_kv_heads_per_cross_attn_layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_cross_attn_layer</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_cross_memory_pool_allocator</span><span class="o">.</span><span class="n">allocate</span><span class="p">(</span>
|
|
<span class="n">kv_cache_type</span><span class="p">,</span> <span class="n">num_kv_heads_per_cross_attn_layer</span><span class="p">)</span>
|
|
|
|
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="s1">'attention'</span><span class="p">:</span>
|
|
<span class="n">cache_shape</span> <span class="o">=</span> <span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="mi">2</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(</span><span class="n">i</span><span class="p">),</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
|
<span class="n">cache_shape</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">kv_cache_type</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
|
|
<span class="n">cross_cache_shape</span> <span class="o">=</span> <span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="mi">2</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="s1">'attention'</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
|
|
<span class="sa">f</span><span class="s1">'cross_present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
|
<span class="n">cross_cache_shape</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">kv_cache_type</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="c1"># Without plugin, we need extra kv cache buffers.</span>
|
|
<span class="c1"># Because we don't support inplace update, so we need separate buffer for inputs and outputs.</span>
|
|
<span class="c1"># We can do reuse between different layers' inputs and outputs, i.e. current layer's output can</span>
|
|
<span class="c1"># reuse previous layer's input memory. But this need one extra buffer as the guard.</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span> <span class="c1"># Not applicable to cross KV buffers as it's constant</span>
|
|
<span class="n">i</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn_to_general_idx</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="n">trt_dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_dtype</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">trt_dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span><span class="p">:</span>
|
|
<span class="c1"># PyTorch doesn't support fp8 datatype, use int8 instead of it because int8 datatype size is same with fp8.</span>
|
|
<span class="c1"># TODO: Remove this section when PyTorch support fp8 datatype</span>
|
|
<span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">int8</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'1_present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
|
<span class="n">cache_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_mamba_conv1d_plugin</span><span class="p">:</span>
|
|
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">conv_kernel</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">rnn_conv_dim_size</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">rnn_conv_dim_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">conv_kernel</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">rnn_head_size</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
|
<span class="n">rnn_state_shape</span> <span class="o">=</span> <span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">rnn_hidden_size</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">rnn_head_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">state_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">rnn_head_size</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">rnn_state_shape</span> <span class="o">=</span> <span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">state_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">rnn_hidden_size</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="s1">'recurrent'</span><span class="p">:</span>
|
|
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
|
<span class="n">conv_state_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'1_present_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
|
<span class="n">conv_state_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_rnn_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
|
<span class="n">rnn_state_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">state_dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
|
|
<span class="n">conv_state_ptr</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">()],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="s1">'cpu'</span><span class="p">)</span>
|
|
<span class="n">rnn_state_ptr</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_rnn_state_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">()],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="s1">'cpu'</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'conv_state_ptr_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span> <span class="o">=</span> <span class="n">conv_state_ptr</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'rnn_state_ptr_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span> <span class="o">=</span> <span class="n">rnn_state_ptr</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_lora_plugin</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">lora_uids</span> <span class="o">=</span> <span class="n">lora_uids</span> <span class="ow">or</span> <span class="p">[</span><span class="s2">"-1"</span><span class="p">]</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">lora_manager</span><span class="o">.</span><span class="n">input_buffers</span><span class="p">(</span>
|
|
<span class="n">lora_uids</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span>
|
|
<span class="p">))</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gemm_allreduce_plugin</span><span class="p">:</span>
|
|
<span class="n">max_num_tokens</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="n">batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span>
|
|
<span class="n">M</span> <span class="o">=</span> <span class="n">max_num_tokens</span>
|
|
<span class="n">N</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_size</span> <span class="o">=</span> <span class="n">M</span> <span class="o">*</span> <span class="n">N</span>
|
|
<span class="n">itemsize</span> <span class="o">=</span> <span class="n">str_dtype_to_torch</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span><span class="o">.</span><span class="n">itemsize</span>
|
|
<span class="n">alloc_bytes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_size</span> <span class="o">*</span> <span class="n">itemsize</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span> <span class="o">=</span> <span class="n">ipc_nvls_allocate</span><span class="p">(</span>
|
|
<span class="n">alloc_bytes</span><span class="p">,</span> <span class="nb">set</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_group</span><span class="p">))</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Allocated NVLS IPC memory: </span><span class="si">{</span><span class="n">alloc_bytes</span><span class="si">}</span><span class="s1"> bytes'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
|
|
<span class="s1">'spec_decoding_packed_mask'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spec_decoding_packed_mask</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
|
|
<span class="s1">'spec_decoding_position_offsets'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spec_decoding_position_offsets</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
|
|
<span class="s1">'spec_decoding_generation_lengths'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spec_decoding_generation_lengths</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'spec_decoding_use'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spec_decoding_use</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span> <span class="o">=</span> <span class="kc">True</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span></div>
|
|
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_allocate_empty_kv_cache_pools</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">kv_cache_type</span><span class="p">,</span> <span class="n">num_blocks</span><span class="p">):</span>
|
|
<span class="c1"># Layers are homogeneous, use old kv cache shape</span>
|
|
<span class="n">unique_cache_pools</span> <span class="o">=</span> <span class="p">[]</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_layer</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">cache_shape</span> <span class="o">=</span> <span class="p">(</span>
|
|
<span class="n">num_blocks</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">,</span>
|
|
<span class="mi">2</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
<span class="n">unique_cache_pools</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">cache_shape</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">kv_cache_type</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">))</span>
|
|
|
|
<span class="c1"># Layers are not homogeneous, use new kv cache shape</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">kv_heads_unique_counter</span> <span class="o">=</span> <span class="n">Counter</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">num_kv_heads_per_layer</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">kv_head</span><span class="p">,</span> <span class="n">num_layers</span> <span class="ow">in</span> <span class="n">kv_heads_unique_counter</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
|
<span class="n">cache_shape</span> <span class="o">=</span> <span class="p">(</span>
|
|
<span class="n">num_blocks</span><span class="p">,</span>
|
|
<span class="n">num_layers</span><span class="p">,</span>
|
|
<span class="mi">2</span><span class="p">,</span>
|
|
<span class="n">kv_head</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
<span class="n">unique_cache_pools</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">cache_shape</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">kv_cache_type</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">))</span>
|
|
|
|
<span class="k">return</span> <span class="n">unique_cache_pools</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_get_context_shape_buffer</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">position_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">host_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">cross_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">host_cross_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">hidden_states_input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">host_runtime_perf_knobs</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">host_context_progress</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">skip_cross_attn_blocks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">language_adapter_routings</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">RuntimeTensor</span><span class="p">]:</span>
|
|
<span class="n">tensors</span> <span class="o">=</span> <span class="p">{}</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">add_tensor_from_pointer</span><span class="p">(</span><span class="n">pointer</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">str_dtype</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
|
|
<span class="n">name</span><span class="p">:</span>
|
|
<span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_pointer</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">pointer</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">str_dtype</span><span class="p">)</span>
|
|
<span class="p">})</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">add_tensor</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="n">name</span><span class="p">:</span> <span class="n">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">)})</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">add_tensor_with_shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
|
|
<span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">override_shape</span><span class="o">=</span><span class="n">shape</span><span class="p">)})</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">add_tensor_with_bs</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">bs</span><span class="p">):</span>
|
|
<span class="c1"># this assumes dim0 to be bs and only overrides dim0 with given bs</span>
|
|
<span class="n">shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
|
<span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">bs</span>
|
|
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
|
|
<span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">override_shape</span><span class="o">=</span><span class="n">shape</span><span class="p">)})</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="s1">'context_lengths'</span><span class="p">)</span>
|
|
<span class="k">assert</span> <span class="n">host_runtime_perf_knobs</span> <span class="o">!=</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"gpt_attention_plugin needs to set host_runtime_perf_knobs"</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_runtime_perf_knobs</span><span class="p">,</span> <span class="s1">'host_runtime_perf_knobs'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_progress</span><span class="p">,</span> <span class="s1">'host_context_progress'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cache_indirection</span><span class="p">,</span> <span class="s1">'cache_indirection'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="s1">'position_ids'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
|
|
<span class="c1"># in context phase, need to generate cross kv cache, set to True</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
|
|
<span class="s1">'cross_kv_cache_gen'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">skip_cross_attn_blocks</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">skip_cross_attn_blocks</span><span class="p">,</span> <span class="s1">'skip_cross_attn_blocks'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip_cross_kv</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_kv_reuse</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="c1"># see Attention's self.qkv output dim</span>
|
|
<span class="n">cross_kv_out_dim</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(</span>
|
|
<span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span>
|
|
<span class="n">cross_kv_shape</span> <span class="o">=</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span>
|
|
<span class="n">cross_kv_out_dim</span><span class="p">,</span> <span class="p">)</span>
|
|
<span class="n">cross_kv_reuse</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">cross_kv_shape</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">encoder_output</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="n">encoder_output</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">cross_kv_reuse</span> <span class="o">=</span> <span class="n">cross_kv_reuse</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cross_kv_reuse</span><span class="p">,</span> <span class="s1">'cross_kv_reuse'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_output</span><span class="p">,</span> <span class="s1">'encoder_output'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="s1">'encoder_input_lengths'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">language_adapter_routings</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">language_adapter_routings</span><span class="p">,</span>
|
|
<span class="s1">'language_adapter_routings'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'encoder_max_input_length'</span><span class="p">],</span>
|
|
<span class="s1">'encoder_max_input_length'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_mask</span><span class="p">,</span> <span class="s1">'cross_attention_mask'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="n">cross_attention_mask</span> <span class="o">!=</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="c1"># cross-attention packed mask (used by fmha).</span>
|
|
<span class="n">cross_attention_packed_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">tensorrt_llm</span><span class="o">.</span><span class="n">pack_fmha_mask_by_input</span><span class="p">(</span>
|
|
<span class="n">cross_attention_mask</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_mask</span><span class="p">,</span> <span class="s1">'cross_attention_mask'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_packed_mask</span><span class="p">,</span>
|
|
<span class="s1">'cross_attention_packed_mask'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="c1"># create a full 1 cross_attention_mask because it is necessary</span>
|
|
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="n">cross_attention_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span><span class="o">.</span><span class="n">prod</span><span class="p">(),</span>
|
|
<span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span><span class="p">)[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">prod</span><span class="p">()),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_mask</span><span class="p">,</span> <span class="s2">"cross_attention_mask"</span><span class="p">)</span>
|
|
<span class="n">cross_attention_packed_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">tensorrt_llm</span><span class="o">.</span><span class="n">pack_fmha_mask_by_input</span><span class="p">(</span>
|
|
<span class="n">cross_attention_mask</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_packed_mask</span><span class="p">,</span>
|
|
<span class="s2">"cross_attention_packed_mask"</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
|
|
<span class="n">hidden_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
|
|
<span class="k">if</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
|
|
<span class="n">hidden_states_input</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span>
|
|
<span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">hidden_size</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">hidden_states_input</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span>
|
|
<span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">hidden_size</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="n">set_redrafter_ctx_tensors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">add_tensor</span><span class="p">,</span> <span class="n">add_tensor_with_bs</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">],</span> <span class="s1">'logits'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'medusa_logits'</span><span class="p">],</span> <span class="s1">'medusa_logits'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="s1">'last_token_ids'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">hidden_states_input</span><span class="p">,</span> <span class="s1">'hidden_states_output'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="s1">'input_ids'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">hidden_states_input</span><span class="p">,</span> <span class="s1">'hidden_states_input'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">prompt_embedding_table</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="s1">'prompt_embedding_table'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">tasks_generation</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">([</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">([</span><span class="n">context_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">()],</span>
|
|
<span class="n">tasks</span><span class="p">[</span><span class="n">b</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">context_lengths</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
|
|
<span class="p">])</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">tasks_generation</span> <span class="o">=</span> <span class="n">tasks</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">tasks_generation</span><span class="p">,</span> <span class="s1">'tasks'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="s1">'prompt_vocab_size'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="n">buffer</span> <span class="o">=</span> <span class="n">kv_cache_block_offsets</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
|
|
<span class="n">shape</span> <span class="o">=</span> <span class="n">kv_cache_block_offsets</span><span class="o">.</span><span class="n">shape</span>
|
|
<span class="n">shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="o">*</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">:]]</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">buffer</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'kv_cache_block_offsets'</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">host_kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'host_kv_cache_block_offsets'</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>
|
|
<span class="n">pool_pointers</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'host_kv_cache_pool_pointers'</span>
|
|
<span class="n">pool_mapping</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'host_kv_cache_pool_mapping'</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">pool_pointers</span><span class="p">],</span> <span class="n">pool_pointers</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">pool_mapping</span><span class="p">],</span> <span class="n">pool_mapping</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
|
|
<span class="n">cross_buffer</span> <span class="o">=</span> <span class="n">cross_kv_cache_block_offsets</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
|
|
<span class="n">cross_shape</span> <span class="o">=</span> <span class="n">cross_kv_cache_block_offsets</span><span class="o">.</span><span class="n">shape</span>
|
|
<span class="n">cross_shape</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="n">cross_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">cross_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">cross_shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span>
|
|
<span class="o">*</span><span class="n">cross_shape</span><span class="p">[</span><span class="mi">3</span><span class="p">:]</span>
|
|
<span class="p">]</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">cross_buffer</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'cross_kv_cache_block_offsets'</span><span class="p">,</span>
|
|
<span class="n">cross_shape</span><span class="p">)</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">host_cross_kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'host_cross_kv_cache_block_offsets'</span><span class="p">,</span>
|
|
<span class="n">cross_shape</span><span class="p">)</span>
|
|
<span class="n">cross_pool_pointers</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'host_cross_kv_cache_pool_pointers'</span>
|
|
<span class="n">cross_pool_mapping</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'host_cross_kv_cache_pool_mapping'</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">cross_pool_pointers</span><span class="p">],</span>
|
|
<span class="n">cross_pool_pointers</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">cross_pool_mapping</span><span class="p">],</span> <span class="n">cross_pool_mapping</span><span class="p">)</span>
|
|
|
|
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_kv_cache</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
|
|
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span>
|
|
<span class="n">idx</span><span class="p">]</span> <span class="o">==</span> <span class="s1">'attention'</span><span class="p">:</span>
|
|
<span class="n">kv_cache_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">general_to_attn_idx</span><span class="p">[</span><span class="n">idx</span><span class="p">]),</span> <span class="mi">0</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
|
|
<span class="c1"># for empty tensor, TRT does not really use the tensor data, so any dtype is fine</span>
|
|
<span class="n">kv_cache_buffer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">kv_cache_buffer</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span>
|
|
<span class="n">kv_cache_shape</span><span class="p">)</span>
|
|
<span class="n">present</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">present</span><span class="p">],</span> <span class="n">present</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
|
|
<span class="n">cross_kv_cache_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span> <span class="mi">0</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
|
|
<span class="c1"># for empty tensor, TRT does not really use the tensor data, so any dtype is fine</span>
|
|
<span class="n">cross_kv_cache_buffer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">cross_kv_cache_buffer</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'cross_past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span>
|
|
<span class="n">cross_kv_cache_shape</span><span class="p">)</span>
|
|
<span class="n">cross_present</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">cross_present</span><span class="p">],</span> <span class="n">cross_present</span><span class="p">)</span>
|
|
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">==</span> <span class="s1">'attention'</span><span class="p">:</span>
|
|
<span class="n">key_value_cache</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span>
|
|
<span class="c1"># when plugin is used, past_ket_value tensor does not need to be empty tensor</span>
|
|
<span class="c1"># because plugin does not care, and does not use this shape.</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">key_value_cache</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">key_value_cache</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
|
|
<span class="n">cross_cache_buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
|
|
<span class="sa">f</span><span class="s1">'cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_cache_buffer</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'cross_past_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_cache_buffer</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'cross_present_key_value_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
|
|
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">!=</span> <span class="s1">'recurrent'</span><span class="p">:</span>
|
|
<span class="k">continue</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'conv_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
|
|
<span class="sa">f</span><span class="s1">'conv_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'rnn_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
|
|
<span class="sa">f</span><span class="s1">'rnn_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="c1"># conv state</span>
|
|
<span class="n">dtype</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span><span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_mamba_conv1d_plugin</span><span class="p">:</span>
|
|
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_kernel</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">rnn_conv_dim_size</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">rnn_conv_dim_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">conv_kernel</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
|
|
|
|
<span class="n">conv_state</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">conv_state_shape</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">conv_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'past_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="n">present</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">present</span><span class="p">],</span> <span class="n">present</span><span class="p">)</span>
|
|
<span class="c1"># rnn state</span>
|
|
<span class="n">rnn_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_rnn_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">rnn_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'past_rnn_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">rnn_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'present_rnn_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
|
|
<span class="n">slot_mapping</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">slot_mapping</span><span class="p">,</span> <span class="s1">'slot_mapping'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="c1"># context request</span>
|
|
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="s1">'cpu'</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="n">device_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span>
|
|
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">device_request_types</span><span class="p">,</span> <span class="s1">'device_request_types'</span><span class="p">)</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span>
|
|
<span class="s1">'sequence_length'</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">))</span>
|
|
|
|
<span class="c1"># field 0: past_key_value_length, field 1: is_context (deprecated). changed to [0], otherwise affects batch padded input mode</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span>
|
|
<span class="s1">'host_past_key_value_lengths'</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">))</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_sink_token_length</span><span class="p">,</span>
|
|
<span class="s1">'host_sink_token_length'</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">))</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_request_types</span><span class="p">,</span> <span class="s1">'host_request_types'</span><span class="p">)</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'host_max_attention_window_sizes'</span><span class="p">,</span>
|
|
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">,</span> <span class="p">))</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="p">,</span> <span class="s1">'host_context_lengths'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
|
|
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="s1">'cpu'</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_request_types</span><span class="p">,</span> <span class="s1">'host_request_types'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="p">,</span> <span class="s1">'host_context_lengths'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">attention_mask</span><span class="p">,</span> <span class="s1">'attention_mask'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span><span class="p">,</span> <span class="s1">'all_reduce_workspace'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gemm_allreduce_plugin</span><span class="p">:</span>
|
|
<span class="n">found_tensor_names</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">)</span>
|
|
<span class="p">]</span>
|
|
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">found_tensor_names</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"gemm_allreduce_uc_out"</span><span class="p">):</span>
|
|
<span class="n">add_tensor_from_pointer</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span><span class="o">.</span><span class="n">uc_ptr</span><span class="p">,</span>
|
|
<span class="n">name</span><span class="p">,</span>
|
|
<span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_size</span><span class="p">),</span>
|
|
<span class="n">str_dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"gemm_allreduce_mc_out"</span><span class="p">):</span>
|
|
<span class="n">add_tensor_from_pointer</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span><span class="o">.</span><span class="n">mc_ptr</span><span class="p">,</span>
|
|
<span class="n">name</span><span class="p">,</span>
|
|
<span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_size</span><span class="p">),</span>
|
|
<span class="n">str_dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"gemm_allreduce_ipc_out"</span><span class="p">):</span>
|
|
<span class="n">add_tensor_from_pointer</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span><span class="o">.</span><span class="n">get_ipc_ptrs</span><span class="p">(),</span>
|
|
<span class="n">name</span><span class="p">,</span>
|
|
<span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_size</span><span class="p">),</span>
|
|
<span class="n">str_dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_lora_plugin</span><span class="p">:</span>
|
|
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span>
|
|
<span class="k">for</span> <span class="n">lora_module</span> <span class="ow">in</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="o">+</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span><span class="p">):</span>
|
|
<span class="n">layer_idx</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span>
|
|
<span class="n">lora_ranks</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_ranks_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">lora_ranks</span><span class="p">],</span> <span class="n">lora_ranks</span><span class="p">)</span>
|
|
<span class="n">lora_weights</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">lora_weights</span><span class="p">],</span> <span class="n">lora_weights</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_input_lengths</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">'cpu'</span><span class="p">),</span>
|
|
<span class="s1">'host_encoder_input_lengths'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
|
|
<span class="c1"># Medusa mask and position offsets are fixed for the whole session.</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'spec_decoding_packed_mask'</span><span class="p">],</span>
|
|
<span class="s1">'spec_decoding_packed_mask'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'spec_decoding_position_offsets'</span><span class="p">],</span>
|
|
<span class="s1">'spec_decoding_position_offsets'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'spec_decoding_generation_lengths'</span><span class="p">],</span>
|
|
<span class="s1">'spec_decoding_generation_lengths'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'spec_decoding_use'</span><span class="p">],</span> <span class="s1">'spec_decoding_use'</span><span class="p">)</span>
|
|
|
|
<span class="k">return</span> <span class="n">tensors</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_get_next_step_shape_buffer</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">position_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">host_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">cross_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">host_cross_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">hidden_states_input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">host_runtime_perf_knobs</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">host_context_progress</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">skip_cross_attn_blocks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">language_adapter_routings</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="p">):</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">"_get_next_step_shape_buffer"</span><span class="p">)</span>
|
|
<span class="n">tensors</span> <span class="o">=</span> <span class="p">{}</span> <span class="c1"># Dict[str, RuntimeTensor]</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">add_tensor_from_pointer</span><span class="p">(</span><span class="n">pointer</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">str_dtype</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">({</span>
|
|
<span class="n">name</span><span class="p">:</span>
|
|
<span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_pointer</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">pointer</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">str_dtype</span><span class="p">)</span>
|
|
<span class="p">})</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">add_tensor</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="n">name</span><span class="p">:</span> <span class="n">sym</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">)})</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">add_tensor_with_shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
|
|
<span class="k">return</span> <span class="n">tensors</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
|
|
<span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="n">RuntimeTensor</span><span class="o">.</span><span class="n">from_torch</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">override_shape</span><span class="o">=</span><span class="n">shape</span><span class="p">)})</span>
|
|
|
|
<span class="n">context_lengths_local</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
|
|
<span class="n">host_context_lengths_local</span> <span class="o">=</span> <span class="n">host_context_lengths</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">context_lengths_local</span><span class="p">,</span> <span class="s1">'context_lengths'</span><span class="p">)</span>
|
|
<span class="k">assert</span> <span class="n">host_runtime_perf_knobs</span> <span class="o">!=</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"gpt_attention_plugin needs to set host_runtime_perf_knobs"</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_runtime_perf_knobs</span><span class="p">,</span> <span class="s1">'host_runtime_perf_knobs'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_progress</span><span class="p">,</span> <span class="s1">'host_context_progress'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cache_indirection</span><span class="p">,</span> <span class="s1">'cache_indirection'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="s1">'position_ids'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
|
|
<span class="n">hidden_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
|
|
<span class="n">shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="n">hidden_size</span><span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span> <span class="k">else</span> <span class="p">(</span>
|
|
<span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span>
|
|
<span class="n">hidden_states_input</span> <span class="o">=</span> <span class="n">hidden_states_input</span><span class="o">.</span><span class="n">resize_</span><span class="p">(</span><span class="o">*</span><span class="n">shape</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">],</span> <span class="s1">'logits'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'medusa_logits'</span><span class="p">],</span> <span class="s1">'medusa_logits'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="s1">'last_token_ids'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">hidden_states_input</span><span class="p">,</span> <span class="s1">'hidden_states_output'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="n">input_ids_shape</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_total_gen_token</span><span class="p">,</span> <span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">input_ids_shape</span> <span class="o">=</span> <span class="p">(</span>
|
|
<span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
|
|
<span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span> <span class="k">else</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'flat_tokens'</span><span class="p">],</span> <span class="s1">'input_ids'</span><span class="p">,</span>
|
|
<span class="n">input_ids_shape</span><span class="p">)</span>
|
|
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span><span class="p">,</span> <span class="s1">'input_ids'</span><span class="p">,</span>
|
|
<span class="n">input_ids_shape</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="s1">'input_ids'</span><span class="p">,</span>
|
|
<span class="n">input_ids_shape</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">hidden_states_input</span><span class="p">,</span> <span class="s1">'hidden_states_input'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="c1"># disable (or minimize) cross qkv computation at generation phase</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip_cross_kv</span><span class="p">:</span>
|
|
<span class="c1"># disable</span>
|
|
<span class="n">encoder_output_shape</span> <span class="o">=</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cross_kv_reuse</span><span class="p">,</span> <span class="s1">'cross_kv_reuse'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="c1"># minimize</span>
|
|
<span class="c1"># use TensorRT Empty Tensor to skip redundant computation</span>
|
|
<span class="c1"># 0 for generation phase, >0 for context phase</span>
|
|
<span class="n">encoder_output_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">encoder_output_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">encoder_output_shape</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]]</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="c1"># OOTB path doesn't have kv cache for now, so this encoder_output is</span>
|
|
<span class="c1"># a must-have input. We just use the encoder_output</span>
|
|
<span class="n">encoder_output_shape</span> <span class="o">=</span> <span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span>
|
|
|
|
<span class="c1"># in generation phase, cross kv cache is already filled during context phase, set to False</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
|
|
<span class="s1">'cross_kv_cache_gen'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_config</span><span class="o">.</span><span class="n">skip_cross_attn_blocks</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">skip_cross_attn_blocks</span><span class="p">,</span> <span class="s1">'skip_cross_attn_blocks'</span><span class="p">)</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">encoder_output</span><span class="p">,</span> <span class="s1">'encoder_output'</span><span class="p">,</span>
|
|
<span class="n">encoder_output_shape</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="s1">'encoder_input_lengths'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">language_adapter_routings</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">language_adapter_routings</span><span class="p">,</span>
|
|
<span class="s1">'language_adapter_routings'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'encoder_max_input_length'</span><span class="p">],</span>
|
|
<span class="s1">'encoder_max_input_length'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_mask</span><span class="p">,</span> <span class="s1">'cross_attention_mask'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="n">cross_attention_mask</span> <span class="o">!=</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">cross_attention_mask</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span>
|
|
<span class="n">cross_attention_mask</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
|
|
<span class="c1"># Empty packed mask is passed in the generation phase as it is not used.</span>
|
|
<span class="n">cross_attention_packed_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="p">(</span><span class="n">cross_attention_mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="mi">31</span><span class="p">)</span> <span class="o">//</span> <span class="mi">32</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_mask</span><span class="p">,</span> <span class="s1">'cross_attention_mask'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_attention_packed_mask</span><span class="p">,</span>
|
|
<span class="s1">'cross_attention_packed_mask'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="c1"># create a full 1 cross_attention_mask because it is necessary in generation phase</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span>
|
|
<span class="n">encoder_output</span><span class="o">.</span><span class="n">shape</span><span class="p">)[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">prod</span><span class="p">()),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="s2">"cross_attention_mask"</span><span class="p">)</span>
|
|
<span class="c1"># Empty packed mask is passed in the generation phase as it is not used.</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
|
|
<span class="s2">"cross_attention_packed_mask"</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="n">shape</span> <span class="o">=</span> <span class="n">kv_cache_block_offsets</span><span class="o">.</span><span class="n">shape</span>
|
|
<span class="n">shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="o">*</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">:]]</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'kv_cache_block_offsets'</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">host_kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'host_kv_cache_block_offsets'</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>
|
|
<span class="n">pool_pointers</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'host_kv_cache_pool_pointers'</span>
|
|
<span class="n">pool_mapping</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'host_kv_cache_pool_mapping'</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">pool_pointers</span><span class="p">],</span> <span class="n">pool_pointers</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">pool_mapping</span><span class="p">],</span> <span class="n">pool_mapping</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
|
|
<span class="n">cross_shape</span> <span class="o">=</span> <span class="n">cross_kv_cache_block_offsets</span><span class="o">.</span><span class="n">shape</span>
|
|
<span class="n">cross_shape</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="n">cross_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">cross_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">cross_shape</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span>
|
|
<span class="o">*</span><span class="n">cross_shape</span><span class="p">[</span><span class="mi">3</span><span class="p">:]</span>
|
|
<span class="p">]</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">cross_kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'cross_kv_cache_block_offsets'</span><span class="p">,</span>
|
|
<span class="n">cross_shape</span><span class="p">)</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">host_cross_kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'host_cross_kv_cache_block_offsets'</span><span class="p">,</span>
|
|
<span class="n">cross_shape</span><span class="p">)</span>
|
|
<span class="n">cross_pool_pointers</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'host_cross_kv_cache_pool_pointers'</span>
|
|
<span class="n">cross_pool_mapping</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'host_cross_kv_cache_pool_mapping'</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">cross_pool_pointers</span><span class="p">],</span>
|
|
<span class="n">cross_pool_pointers</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">cross_pool_mapping</span><span class="p">],</span> <span class="n">cross_pool_mapping</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">prompt_embedding_table</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">prompt_embedding_table</span><span class="p">,</span> <span class="s1">'prompt_embedding_table'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">gen_tasks</span> <span class="o">=</span> <span class="n">tasks</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">gen_tasks</span> <span class="o">=</span> <span class="n">tasks</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">gen_tasks</span><span class="p">,</span> <span class="s1">'tasks'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">prompt_vocab_size</span><span class="p">,</span> <span class="s1">'prompt_vocab_size'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
|
|
<span class="k">for</span> <span class="n">attn_idx</span><span class="p">,</span> <span class="n">layer_idx</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn_to_general_idx</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="n">next_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span>
|
|
<span class="n">max_context_length</span> <span class="o">+</span> <span class="n">step</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
|
|
<span class="c1"># We will make current layer's output KV-cache overwrite previous layers input KV-cache</span>
|
|
<span class="c1"># buffer id: ... 5, 6, 7, 8, 9, ...</span>
|
|
<span class="c1"># layer n: out in</span>
|
|
<span class="c1"># layer n+1: out in</span>
|
|
<span class="c1"># layer n+2 out in</span>
|
|
<span class="c1"># And when finish a step, we will make every layer's in/out buffer index subtract 1 in</span>
|
|
<span class="c1"># a circular buffer way to make sure current outputs become next step's inputs.</span>
|
|
<span class="n">num_buffers</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span> <span class="o">+</span> <span class="mi">1</span>
|
|
<span class="n">input_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">attn_idx</span> <span class="o">-</span> <span class="p">(</span><span class="n">step</span> <span class="o">%</span> <span class="n">num_buffers</span><span class="p">))</span> <span class="o">%</span> <span class="n">num_buffers</span>
|
|
<span class="n">output_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">input_idx</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="n">num_buffers</span>
|
|
<span class="n">input_name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_buffer_names</span><span class="p">[</span><span class="n">input_idx</span><span class="p">]</span>
|
|
<span class="n">output_name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_buffer_names</span><span class="p">[</span><span class="n">output_idx</span><span class="p">]</span>
|
|
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">input_name</span><span class="p">],</span>
|
|
<span class="sa">f</span><span class="s1">'past_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span>
|
|
<span class="n">next_shape</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">output_name</span><span class="p">],</span>
|
|
<span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">key_value_cache</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
|
|
<span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">key_value_cache</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'past_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">key_value_cache</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
|
|
<span class="n">cross_cache_buffer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
|
|
<span class="sa">f</span><span class="s1">'cross_present_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_cache_buffer</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'cross_past_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">cross_cache_buffer</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'cross_present_key_value_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
|
|
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_types</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">!=</span> <span class="s1">'recurrent'</span><span class="p">:</span>
|
|
<span class="k">continue</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'conv_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
|
|
<span class="sa">f</span><span class="s1">'conv_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'rnn_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
|
|
<span class="sa">f</span><span class="s1">'rnn_state_ptr_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="c1"># conv state</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_mamba_conv1d_plugin</span><span class="p">:</span>
|
|
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_kernel</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">rnn_conv_dim_size</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">conv_state_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">rnn_conv_dim_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">conv_kernel</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span><span class="p">:</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'1_present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
|
|
<span class="sa">f</span><span class="s1">'past_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span> <span class="n">conv_state_shape</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
|
|
<span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
|
|
<span class="sa">f</span><span class="s1">'past_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">,</span> <span class="n">conv_state_shape</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'1_present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">],</span>
|
|
<span class="sa">f</span><span class="s1">'present_conv_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="c1"># rnn state</span>
|
|
<span class="n">rnn_state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_rnn_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">rnn_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'past_rnn_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">rnn_state</span><span class="p">,</span> <span class="sa">f</span><span class="s1">'present_rnn_state_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_state</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
|
|
<span class="n">slot_mapping</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">slot_mapping</span><span class="p">,</span> <span class="s1">'slot_mapping'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="c1"># generation requests</span>
|
|
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="s1">'cpu'</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">"device_request_types"</span><span class="p">)</span>
|
|
<span class="n">device_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span>
|
|
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">device_request_types</span><span class="p">,</span> <span class="s1">'device_request_types'</span><span class="p">)</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="n">host_past_key_value_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="c1"># previous [past_kv_length, is_context] has been deprecated. only past_kv_length should be given here</span>
|
|
<span class="c1"># Note we should use max_context_length here to align to max -- but isn't this done in attn plugin's max_element() already?</span>
|
|
<span class="n">host_past_key_value_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="n">max_context_length</span> <span class="o">+</span> <span class="n">step</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="s1">'cpu'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_past_key_value_lengths</span><span class="p">,</span>
|
|
<span class="s1">'host_past_key_value_lengths'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_request_types</span><span class="p">,</span> <span class="s1">'host_request_types'</span><span class="p">)</span>
|
|
<span class="c1"># Sequence lengths are not used in the context phase actually.</span>
|
|
<span class="n">sequence_length</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span>
|
|
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">,</span> <span class="s1">'sequence_length'</span><span class="p">,</span>
|
|
<span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="p">))</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_sink_token_length</span><span class="p">,</span>
|
|
<span class="s1">'host_sink_token_length'</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">))</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_max_attention_window_sizes</span><span class="p">,</span>
|
|
<span class="sa">f</span><span class="s1">'host_max_attention_window_sizes'</span><span class="p">,</span>
|
|
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span><span class="p">,</span> <span class="p">))</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths_local</span><span class="p">,</span> <span class="s1">'host_context_lengths'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">:</span>
|
|
<span class="n">host_request_types</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="s1">'cpu'</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_request_types</span><span class="p">,</span> <span class="s1">'host_request_types'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">host_context_lengths_local</span><span class="p">,</span>
|
|
<span class="s1">'host_context_lengths'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">attention_mask</span><span class="p">,</span> <span class="s1">'attention_mask'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">all_reduce_workspace</span><span class="p">,</span> <span class="s1">'all_reduce_workspace'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gemm_allreduce_plugin</span><span class="p">:</span>
|
|
<span class="n">found_tensor_names</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">get_tensor_name</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">num_io_tensors</span><span class="p">)</span>
|
|
<span class="p">]</span>
|
|
<span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">found_tensor_names</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"gemm_allreduce_uc_out"</span><span class="p">):</span>
|
|
<span class="n">add_tensor_from_pointer</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span><span class="o">.</span><span class="n">uc_ptr</span><span class="p">,</span>
|
|
<span class="n">name</span><span class="p">,</span>
|
|
<span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_size</span><span class="p">),</span>
|
|
<span class="n">str_dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"gemm_allreduce_mc_out"</span><span class="p">):</span>
|
|
<span class="n">add_tensor_from_pointer</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span><span class="o">.</span><span class="n">mc_ptr</span><span class="p">,</span>
|
|
<span class="n">name</span><span class="p">,</span>
|
|
<span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_size</span><span class="p">),</span>
|
|
<span class="n">str_dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">"gemm_allreduce_ipc_out"</span><span class="p">):</span>
|
|
<span class="n">add_tensor_from_pointer</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_handle</span><span class="o">.</span><span class="n">get_ipc_ptrs</span><span class="p">(),</span>
|
|
<span class="n">name</span><span class="p">,</span>
|
|
<span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_output_size</span><span class="p">),</span>
|
|
<span class="n">str_dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span>
|
|
|
|
<span class="c1"># Since we are using a ping-pong context design and the lora weight remains constant within the same request,</span>
|
|
<span class="c1"># it is only necessary to set the lora weight for the first two steps.</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_lora_plugin</span> <span class="ow">and</span> <span class="n">step</span> <span class="o"><</span> <span class="mi">2</span><span class="p">:</span>
|
|
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span>
|
|
<span class="n">layer_idx</span> <span class="o">=</span> <span class="n">idx</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span>
|
|
<span class="k">for</span> <span class="n">lora_module</span> <span class="ow">in</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="o">+</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">missing_qkv_modules</span><span class="p">):</span>
|
|
<span class="n">lora_ranks</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_ranks_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">lora_ranks</span><span class="p">],</span> <span class="n">lora_ranks</span><span class="p">)</span>
|
|
<span class="n">lora_module</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">'</span><span class="si">{</span><span class="n">lora_module</span><span class="si">}</span><span class="s1">_lora_weights_pointers_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">lora_module</span><span class="p">],</span> <span class="n">lora_module</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="n">encoder_input_lengths</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">'cpu'</span><span class="p">),</span>
|
|
<span class="s1">'host_encoder_input_lengths'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
|
|
<span class="c1"># Spec Decoding mask and position offsets are fixed for the whole session for Medusa.</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'spec_decoding_packed_mask'</span><span class="p">],</span>
|
|
<span class="s1">'spec_decoding_packed_mask'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'spec_decoding_position_offsets'</span><span class="p">],</span>
|
|
<span class="s1">'spec_decoding_position_offsets'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'spec_decoding_generation_lengths'</span><span class="p">],</span>
|
|
<span class="s1">'spec_decoding_generation_lengths'</span><span class="p">)</span>
|
|
<span class="n">add_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'spec_decoding_use'</span><span class="p">],</span> <span class="s1">'spec_decoding_use'</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="n">set_redrafter_gen_tensors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">add_tensor</span><span class="p">,</span>
|
|
<span class="n">add_tensor_with_shape</span><span class="p">)</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
|
|
|
|
<span class="k">return</span> <span class="n">tensors</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_prepare_context_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
|
|
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
|
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
|
|
<span class="k">if</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span>
|
|
<span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="c1"># For Medusa, last_token_ids should contain the actual indices</span>
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span> <span class="c1"># sub 1 from context_lengths for indices</span>
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
|
<span class="k">if</span> <span class="p">(</span><span class="n">use_gpt_attention_plugin</span>
|
|
<span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">)</span> <span class="ow">and</span> <span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
|
|
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'last_token_ids'</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">}</span>
|
|
|
|
<span class="k">if</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'max_context_length'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">([</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
|
|
<span class="n">host_context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
|
|
<span class="p">])</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="mi">1</span><span class="p">,</span>
|
|
<span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">expand</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
|
|
|
<span class="n">perf_knob_tensor_size</span> <span class="o">=</span> <span class="mi">16</span>
|
|
<span class="n">context_runtime_perf_knobs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span>
|
|
<span class="n">perf_knob_tensor_size</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_block_mode</span><span class="p">:</span>
|
|
<span class="n">context_runtime_perf_knobs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># multi_block_mode</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable_context_fmha_fp32_acc</span><span class="p">:</span>
|
|
<span class="n">context_runtime_perf_knobs</span><span class="p">[</span>
|
|
<span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># enable_context_fmha_fp32_acc</span>
|
|
<span class="n">ret</span><span class="p">[</span><span class="s1">'host_runtime_perf_knobs'</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_runtime_perf_knobs</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'input_ids'</span><span class="p">)</span>
|
|
<span class="n">pad_id</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'pad_id'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
|
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">_prepare_attention_mask</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">pad_id</span><span class="p">)</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">attention_mask</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
|
|
<span class="n">position_ids</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="n">attention_mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
|
|
<span class="n">ret</span><span class="p">[</span><span class="s1">'attention_mask'</span><span class="p">]</span> <span class="o">=</span> <span class="n">attention_mask</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="n">ret</span><span class="p">[</span><span class="s1">'position_ids'</span><span class="p">]</span> <span class="o">=</span> <span class="n">position_ids</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'position_ids_base'</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
|
|
<span class="c1"># NOTE: Generate random tensors using torch</span>
|
|
<span class="n">redrafter_prepare_random_tensors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">initialize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
|
|
|
<span class="k">return</span> <span class="n">ret</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_prepare_generation_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
|
|
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">"_prepare_generation_inputs"</span><span class="p">)</span>
|
|
|
|
<span class="n">step</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'step'</span><span class="p">)</span>
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">use_gpt_attention_plugin</span> <span class="ow">and</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span>
|
|
<span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
|
|
<span class="c1"># For Medusa, last_token_ids should be [bs * seq] and should contain the actual indices (starts from 1)</span>
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span>
|
|
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="n">context_lengths</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">"last_token_ids_1s"</span><span class="p">)</span>
|
|
<span class="c1"># update last_token_ids here (buffers already swapped)</span>
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">host_total_gen_token</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="n">context_lengths</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="c1"># For Medusa, last_token_ids should be [bs, seq] and should contain the actual indices (starts from 0)</span>
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="n">context_lengths</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">expand</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
|
<span class="k">if</span> <span class="p">(</span><span class="n">use_gpt_attention_plugin</span>
|
|
<span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span><span class="p">)</span> <span class="ow">and</span> <span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">"last_token_ids_cumsum"</span><span class="p">)</span>
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
|
|
<span class="n">ret</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'last_token_ids'</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">}</span>
|
|
|
|
<span class="k">if</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">"position_ids_update"</span><span class="p">)</span>
|
|
<span class="c1"># set position_ids</span>
|
|
<span class="c1"># buffers are swapped but sequence_length is not updated at this point</span>
|
|
|
|
<span class="k">if</span> <span class="n">step</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'position_ids_base'</span><span class="p">]</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
|
|
<span class="s1">'num_accepted_tokens'</span><span class="p">]</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'packed_position_ids'</span><span class="p">]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
|
<span class="o">-</span><span class="mi">1</span><span class="p">)[:</span><span class="bp">self</span><span class="o">.</span><span class="n">host_total_gen_token</span><span class="p">]</span>
|
|
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="n">position_ids</span> <span class="o">-=</span> <span class="mi">1</span>
|
|
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">context_lengths</span> <span class="o">+</span> <span class="n">step</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
|
|
|
<span class="n">perf_knob_tensor_size</span> <span class="o">=</span> <span class="mi">16</span>
|
|
<span class="n">gen_runtime_perf_knobs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">perf_knob_tensor_size</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_block_mode</span><span class="p">:</span>
|
|
<span class="n">gen_runtime_perf_knobs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># multi_block_mode</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable_context_fmha_fp32_acc</span><span class="p">:</span>
|
|
<span class="n">gen_runtime_perf_knobs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># enable_context_fmha_fp32_acc</span>
|
|
<span class="n">ret</span><span class="p">[</span><span class="s1">'host_runtime_perf_knobs'</span><span class="p">]</span> <span class="o">=</span> <span class="n">gen_runtime_perf_knobs</span>
|
|
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'attention_mask'</span><span class="p">)</span>
|
|
<span class="n">num_beams</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'num_beams'</span><span class="p">)</span>
|
|
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">attention_mask</span><span class="p">,</span>
|
|
<span class="n">attention_mask</span><span class="o">.</span><span class="n">new_ones</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">num_beams</span><span class="p">,</span> <span class="mi">1</span><span class="p">))),</span>
|
|
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">attention_mask</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
|
|
<span class="n">position_ids</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="n">attention_mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span>
|
|
<span class="n">ret</span><span class="p">[</span><span class="s1">'attention_mask'</span><span class="p">]</span> <span class="o">=</span> <span class="n">attention_mask</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_position_embedding</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="n">ret</span><span class="p">[</span><span class="s1">'position_ids'</span><span class="p">]</span> <span class="o">=</span> <span class="n">position_ids</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="c1"># buffers are already swapped</span>
|
|
<span class="c1"># convert spec_decoding_mask to spec_decoding_packed_mask</span>
|
|
<span class="n">redrafter_convert_spec_decoding_mask_to_packed_mask</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'spec_decoding_generation_lengths'</span><span class="p">])</span>
|
|
<span class="c1"># NOTE: Generate random tensors using torch</span>
|
|
<span class="n">redrafter_prepare_random_tensors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">)</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
|
|
|
|
<span class="k">return</span> <span class="n">ret</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_prepare_cross_attention_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">cross_attention_mask</span><span class="p">):</span>
|
|
<span class="n">cross_attention_mask_for_context</span> <span class="o">=</span> <span class="p">[]</span>
|
|
<span class="n">cross_attention_mask_for_gen</span> <span class="o">=</span> <span class="p">[]</span>
|
|
<span class="n">max_decoder_input_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
|
|
<span class="k">for</span> <span class="n">batch_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="n">decoder_input_length</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
|
|
<span class="n">local_mask_for_context</span> <span class="o">=</span> <span class="n">cross_attention_mask</span><span class="p">[</span>
|
|
<span class="n">batch_idx</span><span class="p">][:</span><span class="n">decoder_input_length</span><span class="p">,</span> <span class="p">:]</span>
|
|
<span class="n">local_mask_for_gen</span> <span class="o">=</span> <span class="n">cross_attention_mask</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">][</span>
|
|
<span class="n">decoder_input_length</span><span class="p">:,</span> <span class="p">:]</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="n">local_mask_for_context</span> <span class="o">=</span> <span class="n">local_mask_for_context</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">local_mask_for_context</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span>
|
|
<span class="n">local_mask_for_context</span><span class="p">,</span>
|
|
<span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span>
|
|
<span class="p">(</span><span class="n">max_decoder_input_length</span> <span class="o">-</span> <span class="n">decoder_input_length</span><span class="p">)),</span>
|
|
<span class="s2">"constant"</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
|
|
<span class="n">local_mask_for_gen</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span>
|
|
<span class="n">local_mask_for_gen</span><span class="p">,</span>
|
|
<span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span>
|
|
<span class="p">(</span><span class="n">max_decoder_input_length</span> <span class="o">-</span> <span class="n">decoder_input_length</span><span class="p">)),</span>
|
|
<span class="s2">"constant"</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
|
|
<span class="n">cross_attention_mask_for_context</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">local_mask_for_context</span><span class="p">)</span>
|
|
<span class="c1"># add additional dimension for batch size.</span>
|
|
<span class="n">cross_attention_mask_for_gen</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">local_mask_for_gen</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
|
|
|
|
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span><span class="n">cross_attention_mask_for_context</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span>
|
|
<span class="n">cross_attention_mask_for_gen</span><span class="p">)</span>
|
|
|
|
<div class="viewcode-block" id="GenerationSession.pp_communicate_new_tokens">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.pp_communicate_new_tokens">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">pp_communicate_new_tokens</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">should_stop</span><span class="p">,</span> <span class="n">cache_indir</span><span class="p">,</span>
|
|
<span class="n">sequence_length</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="k">for</span> <span class="n">pg</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="n">pg</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span><span class="p">:</span>
|
|
<span class="k">continue</span>
|
|
<span class="n">should_stop</span> <span class="o">=</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">should_stop</span><span class="p">,</span> <span class="n">pg</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">cache_indir</span><span class="p">,</span> <span class="n">pg</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">,</span> <span class="n">pg</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">should_stop</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">should_stop</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">cache_indir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
|
<span class="k">return</span> <span class="n">should_stop</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="GenerationSession.pp_communicate_final_output_ids">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.pp_communicate_final_output_ids">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">pp_communicate_final_output_ids</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">final_output_ids</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">send</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
|
|
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
|
|
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">nccl_comm</span><span class="o">.</span><span class="n">recv</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_group</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
|
<span class="k">return</span> <span class="n">final_output_ids</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="GenerationSession.finalize_decoder">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.finalize_decoder">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">finalize_decoder</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="n">scfg</span><span class="p">,</span>
|
|
<span class="n">in_progress</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
|
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="c1"># output shape of self.gather_tree: [batch_size, beam_width, output_len]</span>
|
|
<span class="n">beam_hyps_args</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_cba</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_seq_len_cba</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs_cba</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores_cba</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs_cba</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span>
|
|
<span class="p">]</span>
|
|
|
|
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span> <span class="ow">and</span> <span class="n">in_progress</span><span class="p">:</span>
|
|
<span class="c1"># self.gather_tree modifies these args.</span>
|
|
<span class="c1"># In streaming mode, this results in incorrect decoding in the following steps.</span>
|
|
<span class="n">beam_hyps_args</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">beam_hyps_args</span><span class="p">)</span>
|
|
|
|
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_tree</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_probs_tiled</span><span class="p">,</span> <span class="o">*</span><span class="n">beam_hyps_args</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">finished</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">length_penalty</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">,</span> <span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span><span class="p">)</span>
|
|
|
|
<span class="c1"># Communicate ranks in Pipeline Parallelism</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
|
|
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pp_communicate_final_output_ids</span><span class="p">(</span>
|
|
<span class="n">final_output_ids</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
|
|
|
|
<span class="k">return</span> <span class="n">final_output_ids</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="GenerationSession.find_best_medusa_path">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.find_best_medusa_path">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">find_best_medusa_path</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">next_logits</span><span class="p">,</span>
|
|
<span class="n">temp</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
|
|
<span class="k">assert</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span>
|
|
<span class="n">best_path</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span>
|
|
<span class="n">best_path_len</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span>
|
|
<span class="n">next_tokens</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span>
|
|
<span class="n">zero_pad</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">input_ids</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="n">input_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">zero_pad</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">temp</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="n">new_tokens_raw</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span>
|
|
<span class="n">next_logits</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span>
|
|
<span class="p">)</span> <span class="c1"># TODO: can be done by treating [bs, nT, vocab] as [bs*nT, vocab] and using decoderOp?</span>
|
|
<span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">new_tokens_raw</span><span class="p">,</span> <span class="n">zero_pad</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">input_paths</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="n">input_ids</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span><span class="p">]</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
|
|
<span class="p">]</span>
|
|
<span class="n">new_paths</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="n">new_tokens</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span><span class="p">]</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
|
|
<span class="p">]</span>
|
|
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="n">equality</span> <span class="o">=</span> <span class="n">input_paths</span><span class="p">[</span><span class="n">b</span><span class="p">][:,</span> <span class="mi">1</span><span class="p">:]</span> <span class="o">==</span> <span class="n">new_paths</span><span class="p">[</span><span class="n">b</span><span class="p">][:,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
|
|
<span class="n">paths_correct_len</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumprod</span><span class="p">(</span><span class="n">equality</span><span class="o">.</span><span class="n">int</span><span class="p">(),</span>
|
|
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">best_path_len</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="n">paths_correct_len</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span>
|
|
<span class="k">if</span> <span class="n">best_path_len</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
|
<span class="n">best_path</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">paths_correct_len</span><span class="p">)</span>
|
|
<span class="n">next_tokens</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_paths</span><span class="p">[</span><span class="n">b</span><span class="p">][</span>
|
|
<span class="n">best_path</span><span class="p">[</span><span class="n">b</span><span class="p">]][:</span><span class="n">best_path_len</span><span class="p">[</span><span class="n">b</span><span class="p">]]</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
|
|
|
|
<span class="k">return</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_len</span><span class="p">,</span> <span class="n">next_tokens</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="GenerationSession.filter_medusa_logits">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.filter_medusa_logits">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">filter_medusa_logits</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span><span class="p">,</span>
|
|
<span class="n">medusa_logits</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""</span>
|
|
<span class="sd"> medusa_logits is of shape [nMH, bs, nMT+1, vocab]</span>
|
|
|
|
<span class="sd"> Returns [nMH, bs, vocab]</span>
|
|
<span class="sd"> """</span>
|
|
<span class="n">filtered_logits</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">medusa_logits</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="n">medusa_logits</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="n">medusa_logits</span> <span class="o">=</span> <span class="n">medusa_logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="n">idx</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span><span class="p">[</span><span class="n">best_path</span><span class="p">[</span><span class="n">b</span><span class="p">],</span> <span class="n">best_path_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span>
|
|
<span class="n">filtered_logits</span><span class="p">[:,</span> <span class="n">b</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">=</span> <span class="n">medusa_logits</span><span class="p">[:,</span> <span class="n">b</span><span class="p">,</span> <span class="n">idx</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span>
|
|
<span class="k">return</span> <span class="n">filtered_logits</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="GenerationSession.get_next_medusa_tokens">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.get_next_medusa_tokens">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">get_next_medusa_tokens</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">next_medusa_logits</span><span class="p">):</span>
|
|
<span class="n">next_medusa_tokens</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="n">next_medusa_logits</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="p">]</span> <span class="c1"># dummy token for now, TODO: update tree_ids and remove this</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">):</span>
|
|
<span class="n">medusa_token</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="n">next_medusa_logits</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:],</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_topks</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
|
|
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">indices</span>
|
|
<span class="n">next_medusa_tokens</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">medusa_token</span><span class="p">)</span>
|
|
<span class="n">next_medusa_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">next_medusa_tokens</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="n">next_medusa_tokens</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="GenerationSession.locate_accepted_draft_tokens">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.locate_accepted_draft_tokens">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">locate_accepted_draft_tokens</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_len</span><span class="p">,</span>
|
|
<span class="n">draft_paths</span><span class="p">):</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">"locate_accepted_draft_tokens"</span><span class="p">)</span>
|
|
<span class="n">best_path_len_tensor</span> <span class="o">=</span> <span class="n">best_path_len</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
|
|
<span class="n">best_path_len</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span>
|
|
<span class="n">best_path_len</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
|
<span class="n">accepted_draft_token_counts</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span>
|
|
<span class="n">best_path_len_tensor</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">best_path_len_tensor</span><span class="o">.</span><span class="n">device</span><span class="p">))</span>
|
|
<span class="n">accepted_draft_token_offsets</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
|
<span class="n">accepted_draft_token_offsets</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span>
|
|
<span class="n">accepted_draft_token_counts</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
|
<span class="n">accepted_draft_token_offsets_cpu</span> <span class="o">=</span> <span class="n">accepted_draft_token_offsets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
|
|
<span class="s1">'cpu'</span><span class="p">)</span>
|
|
<span class="n">packed_accepted_draft_tokens_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span>
|
|
<span class="n">accepted_draft_token_offsets_cpu</span><span class="p">[</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">seq_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="n">cur_draft_paths</span> <span class="o">=</span> <span class="n">draft_paths</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="k">else</span> <span class="n">draft_paths</span><span class="p">[</span>
|
|
<span class="n">seq_idx</span><span class="p">]</span>
|
|
<span class="n">seq_start</span> <span class="o">=</span> <span class="n">accepted_draft_token_offsets_cpu</span><span class="p">[</span><span class="n">seq_idx</span><span class="p">]</span>
|
|
<span class="n">seq_end</span> <span class="o">=</span> <span class="n">accepted_draft_token_offsets_cpu</span><span class="p">[</span><span class="n">seq_idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span>
|
|
<span class="n">seq_accepted_draft_count</span> <span class="o">=</span> <span class="n">seq_end</span> <span class="o">-</span> <span class="n">seq_start</span>
|
|
<span class="n">best_path_idx</span> <span class="o">=</span> <span class="n">best_path</span><span class="p">[</span><span class="n">seq_idx</span><span class="p">]</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
|
|
<span class="n">best_path</span><span class="p">[</span><span class="n">seq_idx</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="n">best_path</span><span class="p">[</span><span class="n">seq_idx</span><span class="p">]</span>
|
|
<span class="n">seq_accepted_token_indices</span> <span class="o">=</span> <span class="n">cur_draft_paths</span><span class="p">[</span>
|
|
<span class="n">best_path_idx</span><span class="p">,</span> <span class="mi">1</span><span class="p">:</span><span class="mi">1</span> <span class="o">+</span> <span class="n">seq_accepted_draft_count</span><span class="p">]</span>
|
|
<span class="n">packed_accepted_draft_tokens_indices</span><span class="p">[</span>
|
|
<span class="n">seq_start</span><span class="p">:</span><span class="n">seq_end</span><span class="p">]</span> <span class="o">=</span> <span class="n">seq_accepted_token_indices</span> <span class="o">-</span> <span class="mi">1</span>
|
|
<span class="c1"># print("KV offsets & indices", accepted_draft_token_offsets,</span>
|
|
<span class="c1"># packed_accepted_draft_tokens_indices,)</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
|
|
<span class="k">return</span> <span class="n">accepted_draft_token_offsets</span><span class="p">,</span> <span class="n">packed_accepted_draft_tokens_indices</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="GenerationSession.update_output_ids_by_offset">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.update_output_ids_by_offset">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">update_output_ids_by_offset</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_generated_ids</span><span class="p">,</span> <span class="n">offsets</span><span class="p">):</span>
|
|
<span class="c1"># output_ids [batch_size, padded_input_length]</span>
|
|
<span class="c1"># new_generated_ids [batch_size, padded_accepted_length]</span>
|
|
<span class="c1"># offsets [batch_size]</span>
|
|
<span class="c1"># FIXME: using fused kernel to update the padded output ids.</span>
|
|
<span class="n">batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="n">offsets</span><span class="p">[</span><span class="n">b</span><span class="p">]:(</span>
|
|
<span class="n">offsets</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
|
|
<span class="p">)]</span> <span class="o">=</span> <span class="n">new_generated_ids</span><span class="p">[</span><span class="n">b</span><span class="p">][:</span><span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]]</span>
|
|
<span class="k">return</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="GenerationSession.next_medusa_input_ids">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.next_medusa_input_ids">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">next_medusa_input_ids</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
|
<span class="c1"># self.new_tokens [batch_size, padded_accepted_length]</span>
|
|
<span class="c1"># self.accept_lengths [batch_size]</span>
|
|
<span class="c1"># self.medusa_new_tokens [batch_size, num_draft_tokens]</span>
|
|
<span class="c1"># FIXME: using fused kernel to generate the new medusa input ids.</span>
|
|
<span class="n">batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">[</span>
|
|
<span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="mi">1</span><span class="p">:]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="p">:]</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="GenerationSession.reorder_kv_cache_for_beam_search">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.reorder_kv_cache_for_beam_search">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">reorder_kv_cache_for_beam_search</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="c1"># Do nothing.</span>
|
|
<span class="k">return</span>
|
|
|
|
<span class="c1"># WAR: This degrades the latency performance in beam search</span>
|
|
<span class="c1"># due to memcpy. Recommend to use gpt attention plugin instead.</span>
|
|
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
|
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span> <span class="o">==</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
|
|
|
|
<span class="n">cache_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span>
|
|
<span class="n">max_context_length</span> <span class="o">+</span> <span class="n">step</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">)</span>
|
|
|
|
<span class="kn">import</span><span class="w"> </span><span class="nn">functools</span>
|
|
<span class="n">numel</span> <span class="o">=</span> <span class="n">functools</span><span class="o">.</span><span class="n">reduce</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="o">*</span> <span class="n">y</span><span class="p">,</span> <span class="n">cache_shape</span><span class="p">)</span>
|
|
|
|
<span class="c1"># attention layer num + 1 extra buffer.</span>
|
|
<span class="n">num_buffers</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_attn_layers</span> <span class="o">+</span> <span class="mi">1</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn_to_general_idx</span><span class="p">:</span>
|
|
<span class="c1"># Cyclic buffers, an output becomes the next step's input.</span>
|
|
<span class="n">input_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">i</span> <span class="o">-</span> <span class="p">(</span><span class="n">step</span> <span class="o">%</span> <span class="n">num_buffers</span><span class="p">))</span> <span class="o">%</span> <span class="n">num_buffers</span>
|
|
<span class="n">presents</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_buffer_names</span><span class="p">[</span><span class="n">input_idx</span><span class="p">]]</span>
|
|
<span class="n">presents</span> <span class="o">=</span> <span class="n">presents</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)[:</span><span class="n">numel</span><span class="p">]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">*</span><span class="n">cache_shape</span><span class="p">)</span>
|
|
<span class="c1"># parent_ids = (batch, beam, max_seq_len)</span>
|
|
<span class="n">parent_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span><span class="p">[</span><span class="o">...</span><span class="p">,</span>
|
|
<span class="n">max_context_length</span> <span class="o">+</span> <span class="n">step</span><span class="p">]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
|
|
|
<span class="k">for</span> <span class="n">batch_beam</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">):</span>
|
|
<span class="n">batch</span> <span class="o">=</span> <span class="n">batch_beam</span> <span class="o">//</span> <span class="n">beam_width</span>
|
|
<span class="k">if</span> <span class="n">parent_ids</span><span class="p">[</span><span class="n">batch_beam</span><span class="p">]</span> <span class="o">!=</span> <span class="n">batch_beam</span> <span class="o">%</span> <span class="n">beam_width</span><span class="p">:</span>
|
|
<span class="c1"># Update past kv cache to parent beam's cache.</span>
|
|
<span class="n">src_bbid</span> <span class="o">=</span> <span class="n">batch</span> <span class="o">*</span> <span class="n">beam_width</span> <span class="o">+</span> <span class="n">parent_ids</span><span class="p">[</span><span class="n">batch_beam</span><span class="p">]</span>
|
|
<span class="n">presents</span><span class="p">[</span><span class="n">batch_beam</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">=</span> <span class="n">presents</span><span class="p">[</span><span class="n">src_bbid</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span></div>
|
|
|
|
|
|
<span class="c1"># OPTIMIZE: need to optimize this early-stop workflow.</span>
|
|
<div class="viewcode-block" id="GenerationSession.early_stop_criteria">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.early_stop_criteria">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">early_stop_criteria</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">should_stop</span><span class="p">):</span>
|
|
<span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_stop</span><span class="p">[</span><span class="n">b</span><span class="p">]:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="k">continue</span>
|
|
<span class="c1"># output sequence length criteria.</span>
|
|
<span class="n">prev_total_output_length</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
|
|
<span class="c1"># end id criteria.</span>
|
|
<span class="n">end_id_mask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">[</span>
|
|
<span class="n">b</span><span class="p">,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]]</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
|
|
<span class="n">should_stop_with_end_id</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">end_id_mask</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_stop</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_stop</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="ow">or</span> <span class="p">(</span>
|
|
<span class="n">prev_total_output_length</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
|
|
<span class="o">>=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">)</span> <span class="ow">or</span> <span class="n">should_stop_with_end_id</span>
|
|
<span class="c1"># update accept lengths for the current step.</span>
|
|
<span class="k">if</span> <span class="p">(</span><span class="n">prev_total_output_length</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
|
|
<span class="o">>=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="n">prev_total_output_length</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">])</span>
|
|
<span class="k">if</span> <span class="n">should_stop_with_end_id</span><span class="p">:</span>
|
|
<span class="c1"># get the position of first end_id.</span>
|
|
<span class="n">end_id_pos</span> <span class="o">=</span> <span class="p">(</span><span class="n">end_id_mask</span><span class="p">)</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">as_tuple</span><span class="o">=</span><span class="kc">True</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">end_id_pos</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">])</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">total_accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
|
|
|
|
<span class="n">should_stop</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">should_stop</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="ow">or</span> <span class="p">(</span><span class="n">step</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span>
|
|
<span class="mi">1</span><span class="p">)</span> <span class="ow">or</span> <span class="n">torch</span><span class="o">.</span><span class="n">all</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_stop</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="n">should_stop</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="GenerationSession.medusa_decode_and_verify">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.medusa_decode_and_verify">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">medusa_decode_and_verify</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">logits</span><span class="p">):</span>
|
|
<span class="n">medusa_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'medusa_logits'</span><span class="p">]</span>
|
|
<span class="n">best_path</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">best_path_lengths</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="c1"># logits buffer is of shape [bs, medusa_tokens+1, vocab]</span>
|
|
<span class="c1"># but during context phase, we get only [bs, 1, vocab] but contiguous</span>
|
|
<span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)[:</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">logits</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">next_main_token_logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
|
|
<span class="n">next_main_token</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">next_main_token_logits</span><span class="p">,</span>
|
|
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span>
|
|
<span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">next_main_token</span>
|
|
<span class="c1"># NOTE: only one token's medusa logit will be written in.</span>
|
|
<span class="n">medusa_logits</span> <span class="o">=</span> <span class="n">medusa_logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
|
|
<span class="o">-</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span>
|
|
<span class="n">next_medusa_logits</span> <span class="o">=</span> <span class="n">medusa_logits</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">num_medusa_heads</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
|
|
<span class="n">next_medusa_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_next_medusa_tokens</span><span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span> <span class="n">next_medusa_logits</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span> <span class="o">=</span> <span class="n">next_medusa_tokens</span><span class="p">[:,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_tree_ids</span><span class="p">[</span>
|
|
<span class="o">-</span><span class="bp">self</span><span class="o">.</span>
|
|
<span class="n">num_draft_tokens</span><span class="p">:]]</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">next_token_logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
|
|
|
|
<span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span><span class="p">,</span> <span class="n">next_main_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">find_best_medusa_path</span><span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">generation_input_ids</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span>
|
|
<span class="n">next_token_logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span>
|
|
<span class="o">-</span><span class="mi">1</span><span class="p">))</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">best_path_lengths</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">to_padded_tensor</span><span class="p">(</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">nested</span><span class="o">.</span><span class="n">nested_tensor</span><span class="p">(</span><span class="n">next_main_tokens</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="c1">#FIXME end id padding.</span>
|
|
<span class="n">next_medusa_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter_medusa_logits</span><span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span><span class="p">,</span> <span class="n">medusa_logits</span><span class="p">)</span>
|
|
<span class="n">next_medusa_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_next_medusa_tokens</span><span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span> <span class="n">next_medusa_logits</span><span class="p">)</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span> <span class="o">=</span> <span class="n">next_medusa_tokens</span><span class="p">[:,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_tree_ids</span><span class="p">[</span>
|
|
<span class="o">-</span><span class="bp">self</span><span class="o">.</span>
|
|
<span class="n">num_draft_tokens</span><span class="p">:]]</span>
|
|
<span class="k">return</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="GenerationSession.process_logits_including_draft">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.process_logits_including_draft">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">process_logits_including_draft</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span>
|
|
<span class="n">next_step_buffer</span><span class="p">):</span>
|
|
<span class="w"> </span><span class="sd">"""</span>
|
|
<span class="sd"> 1. Process logits to tokens and validate (Medusa) or process outputs (ReDrafter)</span>
|
|
<span class="sd"> 2. Extract early stop criteria here : self.accept_length</span>
|
|
<span class="sd"> 3. Update output ids : needs self.new_tokens and past_sequence_length</span>
|
|
<span class="sd"> 4. Get next input_ids : self.[new_tokens, accept_lengths, medusa_output_tokens]</span>
|
|
<span class="sd"> 5. Update KV cache : self.[sequence_length, num_draft_tokens]</span>
|
|
<span class="sd"> 6. Update sequence_length_buffer and past_kv_length</span>
|
|
<span class="sd"> """</span>
|
|
<span class="n">should_stop</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="kc">False</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
|
|
<span class="c1"># NOTE: this function call also updates self.[accept_lengths, new_tokens, medusa_output_tokens]</span>
|
|
<span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_decode_and_verify</span><span class="p">(</span>
|
|
<span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">logits</span><span class="p">)</span>
|
|
<span class="n">last_draft_paths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_paths</span>
|
|
<span class="c1"># print(best_path, self.new_tokens, self.medusa_output_tokens)</span>
|
|
<span class="n">last_draft_tokens_len</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="k">if</span> <span class="n">step</span> <span class="o">></span> <span class="mi">0</span> <span class="k">else</span> <span class="mi">0</span>
|
|
<span class="n">cur_draft_tokens_len</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span>
|
|
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="c1"># buffers are swapped at this point</span>
|
|
<span class="n">last_draft_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'next_draft_tokens'</span><span class="p">]</span>
|
|
<span class="n">new_draft_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'draft_tokens'</span><span class="p">]</span>
|
|
<span class="n">last_draft_paths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s2">"next_draft_indices"</span><span class="p">]</span>
|
|
<span class="n">last_draft_tokens_len</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
|
|
<span class="s1">'next_spec_decoding_generation_lengths'</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">step</span> <span class="o">></span> <span class="mi">0</span> <span class="k">else</span> <span class="mi">0</span>
|
|
<span class="n">cur_draft_tokens_len</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
|
|
<span class="s1">'spec_decoding_generation_lengths'</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span>
|
|
|
|
<span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span> <span class="o">=</span> <span class="n">process_redrafter_outputs</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">last_draft_tokens</span><span class="p">,</span> <span class="n">new_draft_tokens</span><span class="p">)</span>
|
|
<span class="c1"># NOTE: stop criteria</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">"early_stop_check"</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">total_accept_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">medusa_should_stop</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">)</span>
|
|
<span class="n">should_stop</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">equal</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">step</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">early_stop_criteria</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span>
|
|
<span class="n">should_stop</span><span class="p">)</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
|
|
<span class="c1"># NOTE: self.accept_lengths are the lengths of accepted tokens in the current step</span>
|
|
<span class="c1"># NOTE: self.sequence_length_buffer = num_past_kv_cache (accepted) + accept_lengths</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">"update_output_ids"</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">update_output_ids_by_offset</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">-</span> <span class="n">last_draft_tokens_len</span><span class="p">)</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
|
|
|
|
<span class="k">if</span> <span class="n">step</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="mi">1</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">next_medusa_input_ids</span><span class="p">()</span>
|
|
<span class="k">if</span> <span class="n">step</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="k">assert</span> <span class="n">best_path</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">best_path_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
|
<span class="n">accepted_draft_token_offsets</span><span class="p">,</span> <span class="n">packed_accepted_draft_tokens_indices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">locate_accepted_draft_tokens</span><span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span> <span class="n">best_path</span><span class="p">,</span> <span class="n">best_path_lengths</span><span class="p">,</span> <span class="n">last_draft_paths</span><span class="p">)</span>
|
|
<span class="c1"># update the KV cache</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">"kv_update"</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_updater</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
|
|
<span class="n">accepted_draft_token_offsets</span><span class="p">,</span>
|
|
<span class="n">packed_accepted_draft_tokens_indices</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span> <span class="n">last_draft_tokens_len</span><span class="p">)</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">+</span> <span class="n">cur_draft_tokens_len</span> <span class="o">-</span> <span class="n">last_draft_tokens_len</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">+=</span> <span class="n">cur_draft_tokens_len</span> <span class="o">+</span> <span class="mi">1</span>
|
|
|
|
<span class="c1"># NOTE: set the accepted tokens for the last step.</span>
|
|
<span class="k">if</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
|
|
<span class="c1"># remove num_draft_tokens for next generation.</span>
|
|
<span class="c1"># Runtime: denotes kv cache length start positions.</span>
|
|
<span class="c1"># Output: denotes the length of sequence length (input ids + output ids)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span> <span class="o">-</span> <span class="n">last_draft_tokens_len</span>
|
|
|
|
<span class="k">if</span> <span class="n">next_step_buffer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">next_step_buffer</span><span class="p">[</span><span class="s1">'host_past_key_value_lengths'</span><span class="p">]</span><span class="o">.</span><span class="n">to_torch</span><span class="p">()</span><span class="o">.</span><span class="n">copy_</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">)</span>
|
|
|
|
<span class="k">return</span> <span class="n">should_stop</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="GenerationSession.handle_per_step">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.handle_per_step">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">handle_per_step</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="p">,</span>
|
|
<span class="o">*</span><span class="p">,</span>
|
|
<span class="n">cache_indirections</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span>
|
|
<span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">scfg</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
|
|
<span class="n">kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">host_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">cross_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">host_cross_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">host_context_lengths</span><span class="p">,</span>
|
|
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">cross_attention_mask_for_context</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">cross_attention_mask_for_gen</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">ite</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">sequence_limit_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">sequence_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">next_step_tensors</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">RuntimeTensor</span><span class="p">],</span>
|
|
<span class="n">stop_words_data</span><span class="p">,</span>
|
|
<span class="n">bad_words_data</span><span class="p">,</span>
|
|
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">stopping_criteria</span><span class="p">:</span> <span class="n">StoppingCriteria</span><span class="p">,</span>
|
|
<span class="n">logits_processor</span><span class="p">:</span> <span class="n">LogitsProcessor</span><span class="p">,</span>
|
|
<span class="n">output_generation_logits</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
|
|
<span class="o">**</span><span class="n">kwargs</span><span class="p">,</span>
|
|
<span class="p">):</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
|
|
<span class="nb">print</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s2">"=================================== STEP </span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2"> =================================="</span>
|
|
<span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span><span class="p">:</span>
|
|
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_0</span>
|
|
<span class="n">this_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
|
|
<span class="n">this_tgt_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="n">next_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_1</span>
|
|
<span class="n">this_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
|
<span class="n">this_tgt_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
|
|
<span class="n">next_src_cache_indirection</span> <span class="o">=</span> <span class="n">cache_indirections</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
|
|
|
|
<span class="n">position_ids_raw</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'position_ids'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
|
<span class="n">skip_cross_attn_blocks</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'skip_cross_attn_blocks'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
|
<span class="n">language_adapter_routings</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'language_adapter_routings'</span><span class="p">,</span>
|
|
<span class="kc">None</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="n">model_inputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_context_inputs</span><span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">host_context_lengths</span><span class="o">=</span><span class="n">host_context_lengths</span><span class="p">,</span>
|
|
<span class="n">use_gpt_attention_plugin</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
|
|
<span class="n">remove_input_padding</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">,</span>
|
|
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_context_length</span><span class="p">,</span>
|
|
<span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">,</span>
|
|
<span class="n">pad_id</span><span class="o">=</span><span class="n">scfg</span><span class="o">.</span><span class="n">pad_id</span><span class="p">,</span>
|
|
<span class="n">eos_id</span><span class="o">=</span><span class="n">scfg</span><span class="o">.</span><span class="n">end_id</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">position_ids_raw</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="c1"># default iota position ids</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'position_ids'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="c1"># user input position ids</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">position_ids_raw</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">padded_position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">pad_sequence</span><span class="p">(</span>
|
|
<span class="n">position_ids_raw</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">padding_value</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">padded_position_ids</span>
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'last_token_ids'</span><span class="p">)</span>
|
|
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'attention_mask'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
|
<span class="n">context_runtime_perf_knobs</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
|
|
<span class="s1">'host_runtime_perf_knobs'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
|
<span class="n">host_context_progress</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="n">host_kv_cache_block_offsets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span><span class="o">.</span><span class="n">get_block_offsets</span><span class="p">(</span>
|
|
<span class="n">beam_width</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">kv_cache_block_offsets</span> <span class="o">=</span> <span class="n">host_kv_cache_block_offsets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">'cuda'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
|
|
<span class="n">host_cross_kv_cache_block_offsets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_pools_kv_cache_manager</span><span class="o">.</span><span class="n">get_block_offsets</span><span class="p">(</span>
|
|
<span class="n">beam_width</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">cross_kv_cache_block_offsets</span> <span class="o">=</span> <span class="n">host_cross_kv_cache_block_offsets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
|
|
<span class="s1">'cuda'</span><span class="p">)</span>
|
|
|
|
<span class="n">ctx_tensors</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_context_shape_buffer</span><span class="p">(</span>
|
|
<span class="n">input_ids</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">host_context_lengths</span><span class="p">,</span>
|
|
<span class="n">position_ids</span><span class="p">,</span>
|
|
<span class="n">last_token_ids</span><span class="p">,</span>
|
|
<span class="n">attention_mask</span><span class="p">,</span>
|
|
<span class="n">cross_attention_mask_for_context</span><span class="p">,</span>
|
|
<span class="n">this_src_cache_indirection</span><span class="p">,</span>
|
|
<span class="n">kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="n">host_kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="n">cross_kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="n">host_cross_kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="n">hidden_states</span><span class="p">,</span>
|
|
<span class="n">prompt_embedding_table</span><span class="p">,</span>
|
|
<span class="n">tasks</span><span class="p">,</span>
|
|
<span class="n">prompt_vocab_size</span><span class="p">,</span>
|
|
<span class="n">encoder_output</span><span class="p">,</span>
|
|
<span class="n">encoder_input_lengths</span><span class="p">,</span>
|
|
<span class="n">host_runtime_perf_knobs</span><span class="o">=</span><span class="n">context_runtime_perf_knobs</span><span class="p">,</span>
|
|
<span class="n">host_context_progress</span><span class="o">=</span><span class="n">host_context_progress</span><span class="p">,</span>
|
|
<span class="n">skip_cross_attn_blocks</span><span class="o">=</span><span class="n">skip_cross_attn_blocks</span><span class="p">,</span>
|
|
<span class="n">language_adapter_routings</span><span class="o">=</span><span class="n">language_adapter_routings</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
|
|
<span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">ctx_context</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_set_tensors</span><span class="p">(</span><span class="n">context</span><span class="p">,</span> <span class="n">ctx_tensors</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span> <span class="o">=</span> <span class="p">{</span>
|
|
<span class="n">name</span><span class="p">:</span> <span class="n">tensor</span><span class="o">.</span><span class="n">to_torch</span><span class="p">()</span>
|
|
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">tensor</span> <span class="ow">in</span> <span class="n">ctx_tensors</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
|
|
<span class="p">}</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span><span class="p">:</span>
|
|
<span class="c1"># context mode, clean cuda graph instances</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">)]</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span> <span class="ow">and</span> <span class="kc">False</span><span class="p">:</span> <span class="c1"># TODO: after TRT bug is fixed</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_check_tensors</span><span class="p">(</span><span class="n">context</span><span class="p">)</span>
|
|
<span class="c1"># dynamic_decoder currently use torch's current stream, so must let TRT enqueue use same stream here</span>
|
|
<span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">cuda_stream</span>
|
|
<span class="n">instance_idx</span> <span class="o">=</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span>
|
|
<span class="n">instance_idx</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="c1"># launch cuda graph</span>
|
|
<span class="n">CUASSERT</span><span class="p">(</span>
|
|
<span class="n">cudart</span><span class="o">.</span><span class="n">cudaGraphLaunch</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">cuda_graph_instances</span><span class="p">[</span><span class="n">instance_idx</span><span class="p">],</span> <span class="n">stream</span><span class="p">))</span>
|
|
<span class="n">ok</span> <span class="o">=</span> <span class="kc">True</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">ok</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_run</span><span class="p">(</span><span class="n">context</span><span class="p">,</span> <span class="n">stream</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="n">ok</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Executing TRT engine failed step=</span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2">!"</span><span class="p">)</span>
|
|
|
|
<span class="c1"># TODO: remove this Windows WAR after https://nvbugs/4460474 is fixed.</span>
|
|
<span class="k">if</span> <span class="n">platform</span><span class="o">.</span><span class="n">system</span><span class="p">()</span> <span class="o">==</span> <span class="s2">"Windows"</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
|
|
|
|
<span class="n">context_logits</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
|
|
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span>
|
|
<span class="n">context_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
|
|
<span class="c1"># gather last token of context</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="c1"># reshape self.buffer['logits'] from [bs, max_context_length, vocab]</span>
|
|
<span class="c1"># to [1, bs * max_context_length, vocab]</span>
|
|
<span class="c1"># Note that the data are put in the buffer without padding although</span>
|
|
<span class="c1"># the allocated buffer has padding.</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">])</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">index_select</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">],</span> <span class="mi">1</span><span class="p">,</span>
|
|
<span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">],</span>
|
|
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
|
<span class="n">index</span><span class="o">=</span><span class="n">last_token_ids</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">))</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size_padded</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">beam_width</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
|
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span>
|
|
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_rnn_layers</span>
|
|
<span class="c1"># these tiled tensors are returned by handle_per_step(), so they can relay to the next generation calls</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">attention_mask</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
|
|
<span class="n">context_lengths</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
|
|
<span class="n">host_context_lengths</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">host_context_lengths</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">encoder_input_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">encoder_input_lengths</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span>
|
|
<span class="n">encoder_input_lengths</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">tasks</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">tasks</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">tasks</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
|
|
|
|
<span class="c1"># Move tiling before logit computing of context</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
|
|
<span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">:</span>
|
|
<span class="c1"># Note: this tiles both self attn cache and cross attn</span>
|
|
<span class="c1"># cache! both names contain "present_key_value"</span>
|
|
<span class="k">if</span> <span class="s2">"present_key_value"</span> <span class="ow">in</span> <span class="n">key</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">key</span><span class="p">],</span> <span class="n">beam_width</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="c1"># In the OOTB path, KV cache should be contiguously</span>
|
|
<span class="c1"># tiled since TRT engine allocates past_kv cache of</span>
|
|
<span class="c1"># length context_length, i.e., we need a buffer of</span>
|
|
<span class="c1"># shape (batch * beam, 2, heads, context_length, head_size).</span>
|
|
<span class="n">b</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">d</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span>
|
|
<span class="n">numel</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">b</span> <span class="o">*</span> <span class="n">h</span> <span class="o">*</span> <span class="p">(</span><span class="n">max_context_length</span> <span class="o">+</span> <span class="n">step</span><span class="p">)</span> <span class="o">*</span> <span class="n">d</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">_contiguous_tile_beam_width</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="n">key</span><span class="p">],</span> <span class="n">numel</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">],</span> <span class="n">beam_width</span><span class="p">)</span>
|
|
|
|
<span class="n">generation_logits</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span> <span class="ow">or</span> <span class="n">output_generation_logits</span><span class="p">:</span>
|
|
<span class="n">generation_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
|
|
|
|
<span class="c1"># Initialize sequence_lengths (no paddings) for the generation phase.</span>
|
|
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span> <span class="c1"># Medusa/ReDrafter has its own logic</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="c1"># to simplify some processing logic, always swap buffers after execution</span>
|
|
<span class="n">exchange_redrafter_buffers</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
|
|
|
|
<span class="c1"># NOTE: handle next step.</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="n">step</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
|
|
<span class="c1"># Set shape and address for the next step</span>
|
|
<span class="n">model_inputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_generation_inputs</span><span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">use_gpt_attention_plugin</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
|
|
<span class="n">remove_input_padding</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">,</span>
|
|
<span class="n">step</span><span class="o">=</span><span class="n">step</span><span class="p">,</span>
|
|
<span class="n">num_beams</span><span class="o">=</span><span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_mask</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">position_ids_raw</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'position_ids'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="n">p</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">:]</span> <span class="o">+</span> <span class="n">step</span> <span class="o">+</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">position_ids_raw</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'last_token_ids'</span><span class="p">)</span>
|
|
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'attention_mask'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
|
<span class="n">gen_runtime_perf_knobs</span> <span class="o">=</span> <span class="n">model_inputs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'host_runtime_perf_knobs'</span><span class="p">,</span>
|
|
<span class="kc">None</span><span class="p">)</span>
|
|
<span class="n">host_context_progress</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
|
|
|
|
<span class="c1"># Prepare for the next step, and always allocate 1 token slot.</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="c1"># Iterate to the next step in KV cache manager.</span>
|
|
<span class="c1"># Increase number of tokens for all unfinished sequences.</span>
|
|
<span class="c1"># And allocate new blocks if needed.</span>
|
|
<span class="c1"># We set this to False for all sequences, since we use only length criterion to stop now</span>
|
|
<span class="c1"># OPTIMIZE: find a better of adding multiple tokens for paged kv cache.</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">"paged_kv_alloc"</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_draft_tokens</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="n">add_token_count</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_draft_tokens</span> <span class="o">+</span>
|
|
<span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
|
|
<span class="k">assert</span> <span class="n">add_token_count</span> <span class="o">></span> <span class="mi">0</span>
|
|
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">add_token_count</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">False</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="c1"># Allocate kv cache token slots for next step.</span>
|
|
<span class="c1"># Make sure there are always > (num_draft_tokens + 1) free token slots.</span>
|
|
<span class="c1"># Allocate (num_draft_tokens + 1) * 2 for safety as we don't know the current step or next step's accepted lengths.</span>
|
|
<span class="n">add_token_count</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_draft_tokens</span> <span class="o">+</span>
|
|
<span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
|
|
<span class="k">assert</span> <span class="n">add_token_count</span> <span class="o">></span> <span class="mi">0</span>
|
|
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">add_token_count</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">False</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">False</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">"paged_kv_post_alloc"</span><span class="p">)</span>
|
|
<span class="n">host_kv_cache_block_offsets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span><span class="o">.</span><span class="n">get_block_offsets</span><span class="p">(</span>
|
|
<span class="n">beam_width</span><span class="p">)</span>
|
|
<span class="n">kv_cache_block_offsets</span> <span class="o">=</span> <span class="n">host_kv_cache_block_offsets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">'cuda'</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
|
|
<span class="n">host_cross_kv_cache_block_offsets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_pools_kv_cache_manager</span><span class="o">.</span><span class="n">get_block_offsets</span><span class="p">(</span>
|
|
<span class="n">beam_width</span><span class="p">)</span>
|
|
<span class="n">cross_kv_cache_block_offsets</span> <span class="o">=</span> <span class="n">host_cross_kv_cache_block_offsets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
|
|
<span class="s1">'cuda'</span><span class="p">)</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
|
|
|
|
<span class="n">next_context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_1</span> <span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">2</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_0</span>
|
|
<span class="n">cross_attention_mask_step</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="k">if</span> <span class="n">cross_attention_mask_for_gen</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="c1"># cross_attention_mask_for_gen shape [batch_size, max_output_length, max_encoder_input_length]</span>
|
|
<span class="n">decode_step</span> <span class="o">=</span> <span class="n">step</span>
|
|
<span class="k">if</span> <span class="n">decode_step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="n">decode_step</span> <span class="o">+=</span> <span class="mi">1</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="n">cross_attention_mask_step</span> <span class="o">=</span> <span class="n">cross_attention_mask_for_gen</span><span class="p">[:,</span> <span class="p">(</span>
|
|
<span class="n">decode_step</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span> <span class="p">:]</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">cross_attention_mask_step</span> <span class="o">=</span> <span class="n">cross_attention_mask_for_gen</span><span class="p">[:,</span> <span class="p">(</span>
|
|
<span class="n">decode_step</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span><span class="n">decode_step</span><span class="p">,</span> <span class="p">:]</span>
|
|
<span class="n">next_step_tensors</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_next_step_shape_buffer</span><span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="n">max_context_length</span><span class="p">,</span>
|
|
<span class="n">step</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">host_context_lengths</span><span class="p">,</span>
|
|
<span class="n">position_ids</span><span class="p">,</span>
|
|
<span class="n">last_token_ids</span><span class="p">,</span>
|
|
<span class="n">attention_mask</span><span class="p">,</span>
|
|
<span class="n">cross_attention_mask_step</span><span class="p">,</span>
|
|
<span class="n">next_src_cache_indirection</span><span class="p">,</span>
|
|
<span class="n">kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="n">host_kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="n">cross_kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="n">host_cross_kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="n">hidden_states</span><span class="p">,</span>
|
|
<span class="n">prompt_embedding_table</span><span class="p">,</span>
|
|
<span class="n">tasks</span><span class="p">,</span>
|
|
<span class="n">prompt_vocab_size</span><span class="p">,</span>
|
|
<span class="n">encoder_output</span><span class="p">,</span>
|
|
<span class="n">encoder_input_lengths</span><span class="p">,</span>
|
|
<span class="n">host_runtime_perf_knobs</span><span class="o">=</span><span class="n">gen_runtime_perf_knobs</span><span class="p">,</span>
|
|
<span class="n">host_context_progress</span><span class="o">=</span><span class="n">host_context_progress</span><span class="p">,</span>
|
|
<span class="n">skip_cross_attn_blocks</span><span class="o">=</span><span class="n">skip_cross_attn_blocks</span><span class="p">,</span>
|
|
<span class="n">language_adapter_routings</span><span class="o">=</span><span class="n">language_adapter_routings</span><span class="p">)</span>
|
|
|
|
<span class="c1"># there are some tensors created inside the _get_next_step_shape_buffer, not owned by any object</span>
|
|
<span class="c1"># needs to pro-long the life time of the tensors inside the next_step_tensors array</span>
|
|
<span class="c1"># otherwise, it maybe released before the next step actually enqueued</span>
|
|
<span class="c1"># one way to prolong it is to return the list, and destroy it in next step by assigning new values</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_push</span><span class="p">(</span><span class="s2">"_set_tensors"</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_set_tensors</span><span class="p">(</span><span class="n">next_context</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">)</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">nvtx</span><span class="o">.</span><span class="n">range_pop</span><span class="p">()</span>
|
|
|
|
<span class="k">if</span> <span class="n">logger</span><span class="o">.</span><span class="n">level</span> <span class="o">==</span> <span class="s2">"verbose"</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">print_context_info</span><span class="p">(</span>
|
|
<span class="n">next_context</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="n">next_context</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">context_1</span><span class="p">))</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_capture_cuda_graph_and_instantiate</span><span class="p">(</span>
|
|
<span class="n">next_context</span><span class="p">,</span> <span class="n">stream</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span>
|
|
|
|
<span class="n">should_stop</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">logits</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="n">logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">process_logits_including_draft</span><span class="p">(</span>
|
|
<span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">)</span>
|
|
<span class="k">elif</span> <span class="n">logits</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
|
|
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">process_logits_including_draft</span><span class="p">(</span>
|
|
<span class="n">step</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="n">logits_processor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span>
|
|
<span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="n">scfg</span><span class="p">,</span>
|
|
<span class="n">in_progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
|
<span class="c1"># keep the shape as same as huggingface stopping_criteria</span>
|
|
<span class="n">final_output_ids_</span> <span class="o">=</span> <span class="n">final_output_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
|
|
<span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">final_output_ids</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
|
|
<span class="n">logits</span> <span class="o">=</span> <span class="n">logits_processor</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">final_output_ids_</span><span class="p">,</span>
|
|
<span class="n">logits</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="s1">'logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">logits</span>
|
|
<span class="c1"># [batch_size x beam_width, vocab_size_padded] -> [batch_size, beam_width, vocab_size_padded]</span>
|
|
<span class="n">next_token_logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
|
|
<span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="o">-</span><span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">decoder_logits_dtype</span><span class="p">)</span>
|
|
<span class="n">decode_step</span> <span class="o">=</span> <span class="n">step</span> <span class="o">+</span> <span class="n">max_context_length</span>
|
|
|
|
<span class="n">stop_words_list_ptrs</span><span class="p">,</span> <span class="n">stop_words_lens</span><span class="p">,</span> <span class="n">max_stop_words_len</span> <span class="o">=</span> <span class="n">stop_words_data</span>
|
|
<span class="n">bad_words_list_ptrs</span><span class="p">,</span> <span class="n">bad_words_lens</span><span class="p">,</span> <span class="n">max_bad_words_len</span> <span class="o">=</span> <span class="n">bad_words_data</span>
|
|
|
|
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_decoder</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span>
|
|
<span class="n">next_token_logits</span><span class="p">,</span> <span class="n">decode_step</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span><span class="p">,</span>
|
|
<span class="n">ite</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">end_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">embedding_bias_opt</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">sequence_limit_lengths</span><span class="p">,</span>
|
|
<span class="n">stop_words_list_ptrs</span><span class="p">,</span> <span class="n">stop_words_lens</span><span class="p">,</span>
|
|
<span class="n">max_stop_words_len</span><span class="p">,</span> <span class="n">bad_words_list_ptrs</span><span class="p">,</span> <span class="n">bad_words_lens</span><span class="p">,</span>
|
|
<span class="n">max_bad_words_len</span><span class="p">,</span> <span class="n">this_src_cache_indirection</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_tokens</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">finished</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">finished</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">log_probs_tiled</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">parent_ids</span><span class="p">,</span>
|
|
<span class="n">this_tgt_cache_indirection</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_output_ids_cba</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_seq_len_cba</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_cum_log_probs_cba</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_normed_scores_cba</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_log_probs_cba</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_min_normed_scores</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_num_beams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_hyps_is_done</span><span class="p">,</span>
|
|
<span class="n">scfg</span><span class="o">.</span><span class="n">use_beam_hyps</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">reorder_kv_cache_for_beam_search</span><span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">stopping_criteria</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
|
|
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span>
|
|
<span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="n">scfg</span><span class="p">,</span>
|
|
<span class="n">in_progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
|
<span class="c1"># keep the shape as same as huggingface stopping_criteria</span>
|
|
<span class="n">final_output_ids_</span> <span class="o">=</span> <span class="n">final_output_ids</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
|
|
<span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">final_output_ids</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
|
|
<span class="n">should_stop</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">stopping_criteria</span><span class="p">(</span>
|
|
<span class="n">step</span><span class="p">,</span> <span class="n">final_output_ids_</span><span class="p">,</span> <span class="n">logits</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_is_profiling</span><span class="p">():</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="n">context</span><span class="o">.</span><span class="n">report_to_profiler</span><span class="p">():</span>
|
|
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">"Runtime report to profiler failed."</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_insert_step_to_profiler</span><span class="p">(</span><span class="n">step</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
|
|
<span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pp_communicate_new_tokens</span><span class="p">(</span>
|
|
<span class="n">should_stop</span><span class="p">,</span> <span class="n">this_tgt_cache_indirection</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="p">(</span><span class="n">step</span> <span class="o">>=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">should_stop</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
|
<span class="ow">and</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">()):</span>
|
|
<span class="c1"># Free all blocks in all sequences.</span>
|
|
<span class="c1"># With in-flight batching and while loop we'll free some sequences, when they are done</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">True</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">cross_pools_kv_cache_manager</span><span class="o">.</span><span class="n">step</span><span class="p">([</span><span class="kc">True</span><span class="p">]</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">dump_debug_buffers</span><span class="p">(</span><span class="n">step</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">next_step_tensors</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span> <span class="o">=</span> <span class="p">{</span>
|
|
<span class="n">name</span><span class="p">:</span> <span class="n">tensor</span><span class="o">.</span><span class="n">to_torch</span><span class="p">()</span>
|
|
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">tensor</span> <span class="ow">in</span> <span class="n">next_step_tensors</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
|
|
<span class="p">}</span>
|
|
|
|
<span class="k">return</span> <span class="n">should_stop</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">context_logits</span><span class="p">,</span> <span class="n">generation_logits</span><span class="p">,</span> <span class="n">encoder_input_lengths</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="GenerationSession.dump_debug_buffers">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.dump_debug_buffers">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">dump_debug_buffers</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="c1"># restricted written tensors according to filter</span>
|
|
<span class="n">debug_tensor_names</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span><span class="o">.</span><span class="n">keys</span><span class="p">()))</span>
|
|
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">debug_tensor_names</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="nb">all</span><span class="p">([</span><span class="n">kk</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">k</span> <span class="k">for</span> <span class="n">kk</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_tensors_to_save</span><span class="p">]):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="n">k</span><span class="p">)</span>
|
|
|
|
<span class="n">debug_dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s2">"tllm_debug/PP_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_rank</span><span class="si">}</span><span class="s2">/TP_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_rank</span><span class="si">}</span><span class="s2">/CP_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">cp_rank</span><span class="si">}</span><span class="s2">"</span>
|
|
<span class="p">)</span>
|
|
<span class="n">debug_dir</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
|
|
|
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">debug_buffer</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
|
<span class="c1"># convert tensor name to valid file name</span>
|
|
<span class="nb">print</span><span class="p">(</span><span class="s2">"Saving: "</span><span class="p">,</span> <span class="n">name</span><span class="p">)</span>
|
|
<span class="n">fname</span> <span class="o">=</span> <span class="n">name</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"/"</span><span class="p">,</span> <span class="s2">"."</span><span class="p">)</span>
|
|
<span class="n">t</span> <span class="o">=</span> <span class="n">torch_to_numpy</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">float</span><span class="p">())</span>
|
|
<span class="n">np</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">debug_dir</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">fname</span><span class="si">}</span><span class="s2">-step</span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2">.npy"</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
|
|
|
|
<span class="n">txt_format</span> <span class="o">=</span> <span class="s2">"</span><span class="si">%d</span><span class="s2">"</span> <span class="k">if</span> <span class="n">t</span><span class="o">.</span><span class="n">dtype</span> <span class="ow">in</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">]</span> <span class="k">else</span> <span class="s1">'</span><span class="si">%.18e</span><span class="s1">'</span>
|
|
<span class="n">np</span><span class="o">.</span><span class="n">savetxt</span><span class="p">(</span>
|
|
<span class="n">debug_dir</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">fname</span><span class="si">}</span><span class="s2">-step</span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s2">.txt"</span><span class="p">,</span>
|
|
<span class="n">t</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span> <span class="c1"># savetxt accepts 2 dims only</span>
|
|
<span class="n">fmt</span><span class="o">=</span><span class="n">txt_format</span><span class="p">)</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="GenerationSession.decode_regular">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode_regular">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">decode_regular</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
|
<span class="o">*</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">scfg</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
|
|
<span class="n">sequence_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">host_context_lengths</span><span class="p">,</span>
|
|
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">cache_indirections</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span>
|
|
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">ite</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">sequence_limit_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">stop_words_data</span><span class="p">,</span>
|
|
<span class="n">bad_words_data</span><span class="p">,</span>
|
|
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">output_generation_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">stopping_criteria</span><span class="p">:</span> <span class="n">StoppingCriteria</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">logits_processor</span><span class="p">:</span> <span class="n">LogitsProcessor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
|
<span class="n">kv_cache_block_offsets</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">host_kv_cache_block_offsets</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">cross_kv_cache_block_offsets</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">host_cross_kv_cache_block_offsets</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">attention_mask</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">outputs_context_logits</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">outputs_generation_logits</span> <span class="o">=</span> <span class="p">[]</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">get_outputs_dict</span><span class="p">(</span><span class="n">output_ids</span><span class="p">,</span> <span class="n">num_steps</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">):</span>
|
|
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
|
|
<span class="n">outputs</span><span class="p">[</span><span class="s1">'output_ids'</span><span class="p">]</span> <span class="o">=</span> <span class="n">output_ids</span>
|
|
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_log_probs</span><span class="p">:</span>
|
|
<span class="n">outputs</span><span class="p">[</span><span class="s1">'log_probs'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">log_probs</span>
|
|
<span class="k">if</span> <span class="n">scfg</span><span class="o">.</span><span class="n">output_cum_log_probs</span><span class="p">:</span>
|
|
<span class="n">outputs</span><span class="p">[</span><span class="s1">'cum_log_probs'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cum_log_probs</span>
|
|
<span class="k">if</span> <span class="n">output_sequence_lengths</span><span class="p">:</span>
|
|
<span class="n">outputs</span><span class="p">[</span>
|
|
<span class="s1">'sequence_lengths'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">])</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
|
|
<span class="n">outputs</span><span class="p">[</span><span class="s1">'context_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_context_logits</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span> <span class="ow">or</span> <span class="n">output_generation_logits</span><span class="p">:</span>
|
|
<span class="n">outputs</span><span class="p">[</span><span class="s1">'generation_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_generation_logits</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="n">outputs</span><span class="p">[</span><span class="s1">'steps_to_finish'</span><span class="p">]</span> <span class="o">=</span> <span class="n">num_steps</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span><span class="p">:</span>
|
|
<span class="n">outputs</span><span class="p">[</span><span class="s1">'medusa_output_tokens'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_tokens</span>
|
|
<span class="n">outputs</span><span class="p">[</span><span class="s1">'accept_lengths'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">accept_lengths</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_temperature</span> <span class="o">!=</span> <span class="mf">0.0</span><span class="p">:</span>
|
|
<span class="n">outputs</span><span class="p">[</span><span class="s1">'medusa_output_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">medusa_output_logits</span>
|
|
<span class="k">return</span> <span class="n">outputs</span>
|
|
|
|
<span class="n">benchmark_profiler</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'benchmark_profiler'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
|
<span class="n">generation_phase_step_count</span> <span class="o">=</span> <span class="mi">0</span>
|
|
|
|
<span class="k">if</span> <span class="n">benchmark_profiler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">benchmark_profiler</span><span class="o">.</span><span class="n">is_recording_perf_profile</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">runtime</span><span class="o">.</span><span class="n">_set_profiler</span><span class="p">()</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">profile_fn</span><span class="p">(</span><span class="n">benchmark_profiler_obj</span><span class="p">,</span> <span class="n">step_count</span><span class="p">):</span>
|
|
<span class="k">if</span> <span class="n">benchmark_profiler_obj</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">benchmark_profiler_obj</span><span class="o">.</span><span class="n">record_cuda_event</span><span class="p">(</span><span class="s1">'last_token'</span><span class="p">)</span>
|
|
<span class="n">benchmark_profiler_obj</span><span class="o">.</span><span class="n">record_elapsed_time</span><span class="p">(</span>
|
|
<span class="s1">'first_token'</span><span class="p">,</span> <span class="s1">'last_token'</span><span class="p">,</span> <span class="s1">'generation_time'</span><span class="p">)</span>
|
|
<span class="n">benchmark_profiler_obj</span><span class="o">.</span><span class="n">add_aux_info</span><span class="p">(</span><span class="s1">'generation_step_count'</span><span class="p">,</span>
|
|
<span class="n">step_count</span><span class="p">)</span>
|
|
|
|
<span class="c1"># prepare cross attention mask.</span>
|
|
<span class="n">cross_attention_mask_for_context</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">cross_attention_mask_for_gen</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="k">if</span> <span class="n">cross_attention_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">cross_attention_mask_for_context</span><span class="p">,</span> <span class="n">cross_attention_mask_for_gen</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_cross_attention_mask</span><span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="c1"># When we use plugin, the data type of cross_attention_mask is bool.</span>
|
|
<span class="c1"># When we don't use plugin, the data type of cross_attention_mask is int32</span>
|
|
<span class="n">cross_attention_mask_for_context</span> <span class="o">=</span> <span class="n">cross_attention_mask_for_context</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
|
|
<span class="n">cross_attention_mask_for_gen</span> <span class="o">=</span> <span class="n">cross_attention_mask_for_gen</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
|
|
|
|
<span class="n">next_step_tensors</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">):</span>
|
|
|
|
<span class="n">should_stop</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">context_logits</span><span class="p">,</span> <span class="n">generation_logits</span><span class="p">,</span> <span class="n">encoder_input_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">handle_per_step</span><span class="p">(</span>
|
|
<span class="n">cache_indirections</span><span class="o">=</span><span class="n">cache_indirections</span><span class="p">,</span>
|
|
<span class="n">step</span><span class="o">=</span><span class="n">step</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_context_length</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="o">=</span><span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">,</span>
|
|
<span class="n">hidden_states</span><span class="o">=</span><span class="n">hidden_states</span><span class="p">,</span>
|
|
<span class="n">scfg</span><span class="o">=</span><span class="n">scfg</span><span class="p">,</span>
|
|
<span class="n">kv_cache_block_offsets</span><span class="o">=</span><span class="n">kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="n">host_kv_cache_block_offsets</span><span class="o">=</span><span class="n">host_kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="n">cross_kv_cache_block_offsets</span><span class="o">=</span><span class="n">cross_kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="n">host_cross_kv_cache_block_offsets</span><span class="o">=</span>
|
|
<span class="n">host_cross_kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="n">prompt_embedding_table</span><span class="o">=</span><span class="n">prompt_embedding_table</span><span class="p">,</span>
|
|
<span class="n">tasks</span><span class="o">=</span><span class="n">tasks</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">host_context_lengths</span><span class="o">=</span><span class="n">host_context_lengths</span><span class="p">,</span>
|
|
<span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_mask</span><span class="p">,</span>
|
|
<span class="n">cross_attention_mask_for_context</span><span class="o">=</span>
|
|
<span class="n">cross_attention_mask_for_context</span><span class="p">,</span>
|
|
<span class="n">cross_attention_mask_for_gen</span><span class="o">=</span><span class="n">cross_attention_mask_for_gen</span><span class="p">,</span>
|
|
<span class="n">prompt_vocab_size</span><span class="o">=</span><span class="n">prompt_vocab_size</span><span class="p">,</span>
|
|
<span class="n">ite</span><span class="o">=</span><span class="n">ite</span><span class="p">,</span>
|
|
<span class="n">sequence_limit_lengths</span><span class="o">=</span><span class="n">sequence_limit_lengths</span><span class="p">,</span>
|
|
<span class="n">sequence_lengths</span><span class="o">=</span><span class="n">sequence_lengths</span><span class="p">,</span>
|
|
<span class="n">next_step_tensors</span><span class="o">=</span><span class="n">next_step_tensors</span><span class="p">,</span>
|
|
<span class="n">stop_words_data</span><span class="o">=</span><span class="n">stop_words_data</span><span class="p">,</span>
|
|
<span class="n">bad_words_data</span><span class="o">=</span><span class="n">bad_words_data</span><span class="p">,</span>
|
|
<span class="n">encoder_output</span><span class="o">=</span><span class="n">encoder_output</span><span class="p">,</span>
|
|
<span class="n">encoder_input_lengths</span><span class="o">=</span><span class="n">encoder_input_lengths</span><span class="p">,</span>
|
|
<span class="n">stopping_criteria</span><span class="o">=</span><span class="n">stopping_criteria</span><span class="p">,</span>
|
|
<span class="n">logits_processor</span><span class="o">=</span><span class="n">logits_processor</span><span class="p">,</span>
|
|
<span class="n">output_generation_logits</span><span class="o">=</span><span class="n">output_generation_logits</span><span class="p">,</span>
|
|
<span class="o">**</span><span class="n">kwargs</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="n">benchmark_profiler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">benchmark_profiler</span><span class="o">.</span><span class="n">record_cuda_event</span><span class="p">(</span><span class="s1">'first_token'</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">generation_phase_step_count</span> <span class="o">=</span> <span class="n">generation_phase_step_count</span> <span class="o">+</span> <span class="mi">1</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
|
|
<span class="n">outputs_context_logits</span> <span class="o">=</span> <span class="n">context_logits</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span> <span class="ow">or</span> <span class="n">output_generation_logits</span><span class="p">:</span>
|
|
<span class="n">outputs_generation_logits</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">generation_logits</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">should_stop</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
|
|
<span class="n">profile_fn</span><span class="p">(</span><span class="n">benchmark_profiler</span><span class="p">,</span> <span class="n">generation_phase_step_count</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="c1"># just hack away for now</span>
|
|
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_ids</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
|
|
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="n">final_output_ids</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span>
|
|
<span class="n">max_seq_length</span> <span class="o">-</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_draft_tokens</span><span class="p">]</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span>
|
|
<span class="n">context_lengths</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">,</span> <span class="n">scfg</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
|
|
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">,</span> <span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="n">final_output_ids</span>
|
|
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
|
|
<span class="n">outputs</span><span class="p">[</span><span class="s1">'context_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_context_logits</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span> <span class="ow">or</span> <span class="n">output_generation_logits</span><span class="p">:</span>
|
|
<span class="n">outputs</span><span class="p">[</span><span class="s1">'generation_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_generation_logits</span>
|
|
<span class="k">return</span> <span class="n">outputs</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="kc">None</span>
|
|
|
|
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">,</span> <span class="s2">"the custom decoder doesn't support medusa/redrafter."</span>
|
|
|
|
<span class="n">profile_fn</span><span class="p">(</span><span class="n">benchmark_profiler</span><span class="p">,</span> <span class="n">generation_phase_step_count</span><span class="p">)</span>
|
|
|
|
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="p">,</span> <span class="n">scfg</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
|
|
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="n">final_output_ids</span>
|
|
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_last_pp_rank</span><span class="p">():</span>
|
|
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
|
|
<span class="n">outputs</span><span class="p">[</span><span class="s1">'context_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_context_logits</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_generation_logits</span> <span class="ow">or</span> <span class="n">output_generation_logits</span><span class="p">:</span>
|
|
<span class="n">outputs</span><span class="p">[</span><span class="s1">'generation_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_generation_logits</span>
|
|
<span class="k">return</span> <span class="n">outputs</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="kc">None</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="GenerationSession.decode_stream">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode_stream">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">decode_stream</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
|
<span class="o">*</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">scfg</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
|
|
<span class="n">sequence_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">host_context_lengths</span><span class="p">,</span>
|
|
<span class="n">max_context_length</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">cache_indirections</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span>
|
|
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">hidden_states</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">ite</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">sequence_limit_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">stop_words_data</span><span class="p">,</span>
|
|
<span class="n">bad_words_data</span><span class="p">,</span>
|
|
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">output_generation_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">stopping_criteria</span><span class="p">:</span> <span class="n">StoppingCriteria</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">logits_processor</span><span class="p">:</span> <span class="n">LogitsProcessor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
|
<span class="n">kv_cache_block_offsets</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">host_kv_cache_block_offsets</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">cross_kv_cache_block_offsets</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">host_cross_kv_cache_block_offsets</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">attention_mask</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">outputs_context_logits</span> <span class="o">=</span> <span class="kc">None</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">get_outputs_dict</span><span class="p">(</span><span class="n">output_ids</span><span class="p">):</span>
|
|
<span class="n">outputs</span> <span class="o">=</span> <span class="p">{}</span>
|
|
<span class="n">outputs</span><span class="p">[</span><span class="s1">'output_ids'</span><span class="p">]</span> <span class="o">=</span> <span class="n">output_ids</span>
|
|
<span class="k">if</span> <span class="n">output_sequence_lengths</span><span class="p">:</span>
|
|
<span class="n">outputs</span><span class="p">[</span>
|
|
<span class="s1">'sequence_lengths'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length_buffer</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">beam_width</span><span class="p">])</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">:</span>
|
|
<span class="n">outputs</span><span class="p">[</span><span class="s1">'context_logits'</span><span class="p">]</span> <span class="o">=</span> <span class="n">outputs_context_logits</span>
|
|
<span class="k">return</span> <span class="n">outputs</span>
|
|
|
|
<span class="c1"># prepare cross attention mask.</span>
|
|
<span class="n">cross_attention_mask_for_context</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">cross_attention_mask_for_gen</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="k">if</span> <span class="n">cross_attention_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">cross_attention_mask_for_context</span><span class="p">,</span> <span class="n">cross_attention_mask_for_gen</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_prepare_cross_attention_mask</span><span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">cross_attention_mask</span><span class="p">)</span>
|
|
|
|
<span class="n">next_step_tensors</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_new_tokens</span><span class="p">):</span>
|
|
|
|
<span class="n">should_stop</span><span class="p">,</span> <span class="n">next_step_tensors</span><span class="p">,</span> <span class="n">tasks</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">context_logits</span><span class="p">,</span> <span class="n">generation_logits</span><span class="p">,</span> <span class="n">encoder_input_lengths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">handle_per_step</span><span class="p">(</span>
|
|
<span class="n">cache_indirections</span><span class="o">=</span><span class="n">cache_indirections</span><span class="p">,</span>
|
|
<span class="n">step</span><span class="o">=</span><span class="n">step</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_context_length</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="o">=</span><span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">,</span>
|
|
<span class="n">hidden_states</span><span class="o">=</span><span class="n">hidden_states</span><span class="p">,</span>
|
|
<span class="n">scfg</span><span class="o">=</span><span class="n">scfg</span><span class="p">,</span>
|
|
<span class="n">kv_cache_block_offsets</span><span class="o">=</span><span class="n">kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="n">host_kv_cache_block_offsets</span><span class="o">=</span><span class="n">host_kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="n">cross_kv_cache_block_offsets</span><span class="o">=</span><span class="n">cross_kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="n">host_cross_kv_cache_block_offsets</span><span class="o">=</span>
|
|
<span class="n">host_cross_kv_cache_block_offsets</span><span class="p">,</span>
|
|
<span class="n">prompt_embedding_table</span><span class="o">=</span><span class="n">prompt_embedding_table</span><span class="p">,</span>
|
|
<span class="n">tasks</span><span class="o">=</span><span class="n">tasks</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">host_context_lengths</span><span class="o">=</span><span class="n">host_context_lengths</span><span class="p">,</span>
|
|
<span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_mask</span><span class="p">,</span>
|
|
<span class="n">cross_attention_mask_for_context</span><span class="o">=</span>
|
|
<span class="n">cross_attention_mask_for_context</span><span class="p">,</span>
|
|
<span class="n">cross_attention_mask_for_gen</span><span class="o">=</span><span class="n">cross_attention_mask_for_gen</span><span class="p">,</span>
|
|
<span class="n">prompt_vocab_size</span><span class="o">=</span><span class="n">prompt_vocab_size</span><span class="p">,</span>
|
|
<span class="n">ite</span><span class="o">=</span><span class="n">ite</span><span class="p">,</span>
|
|
<span class="n">sequence_limit_lengths</span><span class="o">=</span><span class="n">sequence_limit_lengths</span><span class="p">,</span>
|
|
<span class="n">sequence_lengths</span><span class="o">=</span><span class="n">sequence_lengths</span><span class="p">,</span>
|
|
<span class="n">next_step_tensors</span><span class="o">=</span><span class="n">next_step_tensors</span><span class="p">,</span>
|
|
<span class="n">stop_words_data</span><span class="o">=</span><span class="n">stop_words_data</span><span class="p">,</span>
|
|
<span class="n">bad_words_data</span><span class="o">=</span><span class="n">bad_words_data</span><span class="p">,</span>
|
|
<span class="n">encoder_output</span><span class="o">=</span><span class="n">encoder_output</span><span class="p">,</span>
|
|
<span class="n">encoder_input_lengths</span><span class="o">=</span><span class="n">encoder_input_lengths</span><span class="p">,</span>
|
|
<span class="n">stopping_criteria</span><span class="o">=</span><span class="n">stopping_criteria</span><span class="p">,</span>
|
|
<span class="n">logits_processor</span><span class="o">=</span><span class="n">logits_processor</span><span class="p">,</span>
|
|
<span class="n">output_generation_logits</span><span class="o">=</span><span class="n">output_generation_logits</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
<span class="k">if</span> <span class="n">step</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="n">outputs_context_logits</span> <span class="o">=</span> <span class="n">context_logits</span>
|
|
<span class="k">if</span> <span class="n">should_stop</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
|
|
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="n">scfg</span><span class="p">,</span>
|
|
<span class="n">in_progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
|
|
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
|
|
<span class="k">yield</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">yield</span> <span class="n">final_output_ids</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">yield</span> <span class="kc">None</span>
|
|
|
|
<span class="k">if</span> <span class="n">should_stop</span><span class="o">.</span><span class="n">item</span><span class="p">():</span>
|
|
<span class="k">return</span>
|
|
|
|
<span class="n">final_output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">finalize_decoder</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="p">,</span> <span class="n">scfg</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">is_first_pp_rank</span><span class="p">():</span>
|
|
<span class="k">if</span> <span class="n">return_dict</span><span class="p">:</span>
|
|
<span class="k">yield</span> <span class="n">get_outputs_dict</span><span class="p">(</span><span class="n">final_output_ids</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">yield</span> <span class="n">final_output_ids</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">yield</span> <span class="kc">None</span></div>
|
|
|
|
|
|
<div class="viewcode-block" id="GenerationSession.decode_batch">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode_batch">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">decode_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">input_ids</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
|
|
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
|
|
<span class="n">streaming</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
|
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
|
<span class="n">input_ids</span><span class="p">,</span> <span class="n">context_lengths</span> <span class="o">=</span> <span class="n">_prepare_input_ids</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">sampling_config</span><span class="p">,</span>
|
|
<span class="n">streaming</span><span class="o">=</span><span class="n">streaming</span><span class="p">,</span>
|
|
<span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
|
|
|
|
|
|
<span class="c1"># As dynamic_decoder uses torch's current stream, we must ensure it runs on the same stream that</span>
|
|
<span class="c1"># dynamic_decoder was set up with</span>
|
|
<div class="viewcode-block" id="GenerationSession.decode">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.GenerationSession.decode">[docs]</a>
|
|
<span class="nd">@cuda_stream_guard</span>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
|
|
<span class="n">prompt_embedding_table</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">tasks</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">prompt_vocab_size</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">stop_words_list</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">bad_words_list</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">streaming</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">output_sequence_lengths</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">output_generation_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">return_dict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">encoder_output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">stopping_criteria</span><span class="p">:</span> <span class="n">StoppingCriteria</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">logits_processor</span><span class="p">:</span> <span class="n">LogitsProcessor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">cross_attention_mask</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
|
<span class="n">scfg</span> <span class="o">=</span> <span class="n">sampling_config</span>
|
|
<span class="n">batch_size</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
|
<span class="n">beam_width</span> <span class="o">=</span> <span class="n">scfg</span><span class="o">.</span><span class="n">num_beams</span>
|
|
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
|
|
<span class="n">host_context_lengths</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
|
|
<span class="k">assert</span> <span class="n">batch_size</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> \
|
|
<span class="s2">"Given batch size is different from the one used in setup(),"</span> \
|
|
<span class="s2">"rerun the setup function with the new batch size to avoid buffer overflow."</span>
|
|
<span class="k">assert</span> <span class="n">max_context_length</span> <span class="o"><=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_context_length</span><span class="p">,</span> \
|
|
<span class="s2">"Given input length is large then the one used in setup(),"</span> \
|
|
<span class="s2">"rerun the setup function with the new max_context_length to avoid buffer overflow."</span>
|
|
<span class="k">assert</span> <span class="n">beam_width</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">beam_width</span><span class="p">,</span> \
|
|
<span class="s2">"Given beam width is different from the one used in setup(),"</span> \
|
|
<span class="s2">"rerun the setup function with the new beam width to avoid buffer overflow."</span>
|
|
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span> <span class="o"><=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> \
|
|
<span class="s2">"Given sink token length is larger than shortest context length,"</span> \
|
|
<span class="s2">"rerun the setup function with a smaller sink token length."</span>
|
|
<span class="n">ite</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># index of local batches, will always be 0 if pp_size = 1</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">remove_input_padding</span> <span class="ow">and</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
|
|
<span class="k">assert</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span>
|
|
<span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"Packed 2D input must have shape [1, <sum of input lengths>]"</span>
|
|
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">input_ids</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">__setup_decoder</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">scfg</span><span class="p">,</span> <span class="n">host_context_lengths</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer_allocated</span><span class="p">:</span>
|
|
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">'Buffer not allocated, please call setup first!'</span><span class="p">)</span>
|
|
|
|
<span class="n">sequence_limit_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
|
|
<span class="c1"># Sequence_lengths for the dynamic decoder still has the input paddings.</span>
|
|
<span class="n">sequence_lengths</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
|
<span class="n">max_context_length</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
|
|
<span class="n">cache_indirections</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span>
|
|
<span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
|
|
<span class="p">),</span>
|
|
<span class="mi">0</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span>
|
|
<span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
|
|
<span class="p">),</span>
|
|
<span class="mi">0</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
|
|
<span class="p">]</span> <span class="c1"># ping-pong buffers</span>
|
|
|
|
<span class="n">hidden_states</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">has_pp</span><span class="p">():</span>
|
|
<span class="n">max_num_tokens</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="n">batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_length</span><span class="p">)</span>
|
|
<span class="n">hidden_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
|
|
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">max_num_tokens</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">))</span>
|
|
|
|
<span class="c1"># Init KV cache block manager</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">has_attn_layers</span><span class="p">:</span>
|
|
<span class="n">num_blocks</span><span class="p">,</span> <span class="n">max_blocks_per_seq</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_num_paged_blocks</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span><span class="p">)</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
|
|
<span class="sa">f</span><span class="s1">'host_kv_cache_pool_pointers'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_memory_pool_allocator</span><span class="o">.</span><span class="n">get_kv_cache_pool_pointers</span><span class="p">(</span>
|
|
<span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
|
|
<span class="sa">f</span><span class="s1">'host_kv_cache_pool_mapping'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_memory_pool_allocator</span><span class="o">.</span><span class="n">pool_mapping</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span> <span class="o">=</span> <span class="n">PoolsKVCacheManager</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_memory_pool_allocator</span><span class="o">.</span><span class="n">pools_metadata</span><span class="p">,</span>
|
|
<span class="n">max_blocks_per_seq</span><span class="p">,</span>
|
|
<span class="n">num_blocks</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
|
|
<span class="n">max_attention_window_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window_size</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="o">=</span><span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="n">sink_token_len</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
|
|
<span class="n">cross_num_blocks</span><span class="p">,</span> <span class="n">max_cross_blocks_per_seq</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_num_paged_blocks</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span><span class="p">,</span> <span class="n">sink_token_length</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
|
|
<span class="sa">f</span><span class="s1">'host_cross_kv_cache_pool_pointers'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cross_memory_pool_allocator</span><span class="o">.</span><span class="n">get_kv_cache_pool_pointers</span><span class="p">(</span>
|
|
<span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span>
|
|
<span class="sa">f</span><span class="s1">'host_cross_kv_cache_pool_mapping'</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cross_memory_pool_allocator</span><span class="o">.</span><span class="n">pool_mapping</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">cross_pools_kv_cache_manager</span> <span class="o">=</span> <span class="n">PoolsKVCacheManager</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">_cross_memory_pool_allocator</span><span class="o">.</span><span class="n">pools_metadata</span><span class="p">,</span>
|
|
<span class="n">max_cross_blocks_per_seq</span><span class="p">,</span>
|
|
<span class="n">cross_num_blocks</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
|
|
<span class="n">max_attention_window_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="o">=</span><span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="n">sink_token_len</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span><span class="p">)</span>
|
|
|
|
<span class="c1"># Add sequences to the manager</span>
|
|
<span class="k">for</span> <span class="n">bi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="n">generation_sequence</span> <span class="o">=</span> <span class="n">GenerationSequence</span><span class="p">(</span><span class="n">seq_idx</span><span class="o">=</span><span class="n">bi</span><span class="p">,</span>
|
|
<span class="n">batch_idx</span><span class="o">=</span><span class="n">bi</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span><span class="o">.</span><span class="n">add_sequence</span><span class="p">(</span>
|
|
<span class="n">generation_sequence</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">)</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span><span class="p">:</span>
|
|
<span class="n">cross_generation_sequence</span> <span class="o">=</span> <span class="n">GenerationSequence</span><span class="p">(</span><span class="n">seq_idx</span><span class="o">=</span><span class="n">bi</span><span class="p">,</span>
|
|
<span class="n">batch_idx</span><span class="o">=</span><span class="n">bi</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">cross_pools_kv_cache_manager</span><span class="o">.</span><span class="n">add_sequence</span><span class="p">(</span>
|
|
<span class="n">cross_generation_sequence</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">encoder_max_input_length</span><span class="p">,</span>
|
|
<span class="n">always_share_across_beam</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
|
<span class="c1"># cross attention paged kv cache should always share the context blocks across beams</span>
|
|
<span class="c1"># due to the fact that we are not adding new key/value cache to cross kv in generation</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_medusa_mode</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_redrafter_mode</span><span class="p">:</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_kv_cache_quant</span><span class="p">():</span>
|
|
<span class="c1"># Since torch does not support fp8 now, using int8 here.</span>
|
|
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">int8</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensor_dtype</span><span class="p">(</span>
|
|
<span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">history_max_seq_length</span> <span class="o">=</span> <span class="p">[</span><span class="n">max_context_length</span><span class="p">]</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_updater</span> <span class="o">=</span> <span class="n">KVCacheUpdater</span><span class="p">()</span>
|
|
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">cross_attention</span>
|
|
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_gpt_attention_plugin</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">paged_kv_cache</span><span class="p">:</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_updater</span><span class="o">.</span><span class="n">init_paged_kv_cache</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
|
|
<span class="n">kv_cache_type</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">pools_kv_cache_manager</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'host_kv_cache_pool_pointers'</span><span class="p">])</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">past_key_value_list</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="p">[</span><span class="sa">f</span><span class="s1">'present_key_value_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">'</span><span class="p">]</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_layer</span><span class="p">)</span>
|
|
<span class="p">]</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_updater</span><span class="o">.</span><span class="n">init_linear_kv_cache</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_num_heads_kv</span><span class="p">(),</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_size</span><span class="p">,</span>
|
|
<span class="n">kv_cache_type</span><span class="p">,</span> <span class="n">past_key_value_list</span><span class="p">)</span>
|
|
|
|
<span class="n">stop_words_lens</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">stop_words_list_ptrs</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">max_stop_words_len</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="k">if</span> <span class="n">stop_words_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">stop_words_list</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">stop_words_list</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
|
|
<span class="s1">'cuda'</span><span class="p">)</span>
|
|
<span class="n">max_stop_words_len</span> <span class="o">=</span> <span class="n">stop_words_list</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
|
|
<span class="n">stop_words_lens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
|
|
<span class="n">max_stop_words_len</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">'cuda'</span><span class="p">)</span>
|
|
<span class="n">stop_words_list_ptrs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">bi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="n">stop_words_list_ptrs</span><span class="p">[</span><span class="n">bi</span><span class="p">]</span> <span class="o">=</span> <span class="n">stop_words_list</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(</span>
|
|
<span class="p">)</span> <span class="o">+</span> <span class="n">bi</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">max_stop_words_len</span> <span class="o">*</span> <span class="n">stop_words_list</span><span class="o">.</span><span class="n">element_size</span><span class="p">(</span>
|
|
<span class="p">)</span>
|
|
<span class="n">stop_words_list_ptrs</span> <span class="o">=</span> <span class="n">stop_words_list_ptrs</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">'cuda'</span><span class="p">)</span>
|
|
<span class="n">stop_words_data</span> <span class="o">=</span> <span class="p">(</span><span class="n">stop_words_list_ptrs</span><span class="p">,</span> <span class="n">stop_words_lens</span><span class="p">,</span>
|
|
<span class="n">max_stop_words_len</span><span class="p">)</span>
|
|
|
|
<span class="n">bad_words_lens</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">bad_words_list_ptrs</span> <span class="o">=</span> <span class="kc">None</span>
|
|
<span class="n">max_bad_words_len</span> <span class="o">=</span> <span class="mi">0</span>
|
|
<span class="k">if</span> <span class="n">bad_words_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
|
<span class="n">bad_words_list</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">bad_words_list</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
|
|
<span class="s1">'cuda'</span><span class="p">)</span>
|
|
<span class="n">max_bad_words_len</span> <span class="o">=</span> <span class="n">bad_words_list</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
|
|
<span class="n">bad_words_lens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">),</span>
|
|
<span class="n">max_bad_words_len</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">'cuda'</span><span class="p">)</span>
|
|
<span class="n">bad_words_list_ptrs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">bi</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="n">bad_words_list_ptrs</span><span class="p">[</span><span class="n">bi</span><span class="p">]</span> <span class="o">=</span> <span class="n">bad_words_list</span><span class="o">.</span><span class="n">data_ptr</span><span class="p">(</span>
|
|
<span class="p">)</span> <span class="o">+</span> <span class="n">bi</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">max_bad_words_len</span> <span class="o">*</span> <span class="n">bad_words_list</span><span class="o">.</span><span class="n">element_size</span><span class="p">()</span>
|
|
<span class="n">bad_words_list_ptrs</span> <span class="o">=</span> <span class="n">bad_words_list_ptrs</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">'cuda'</span><span class="p">)</span>
|
|
<span class="n">bad_words_data</span> <span class="o">=</span> <span class="p">(</span><span class="n">bad_words_list_ptrs</span><span class="p">,</span> <span class="n">bad_words_lens</span><span class="p">,</span>
|
|
<span class="n">max_bad_words_len</span><span class="p">)</span>
|
|
|
|
<span class="c1"># start context phase</span>
|
|
<span class="k">if</span> <span class="n">streaming</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode_stream</span><span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">scfg</span><span class="o">=</span><span class="n">scfg</span><span class="p">,</span>
|
|
<span class="n">sequence_lengths</span><span class="o">=</span><span class="n">sequence_lengths</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">host_context_lengths</span><span class="o">=</span><span class="n">host_context_lengths</span><span class="p">,</span>
|
|
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_context_length</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="o">=</span><span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="n">cache_indirections</span><span class="o">=</span><span class="n">cache_indirections</span><span class="p">,</span>
|
|
<span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">,</span>
|
|
<span class="n">hidden_states</span><span class="o">=</span><span class="n">hidden_states</span><span class="p">,</span>
|
|
<span class="n">prompt_embedding_table</span><span class="o">=</span><span class="n">prompt_embedding_table</span><span class="p">,</span>
|
|
<span class="n">tasks</span><span class="o">=</span><span class="n">tasks</span><span class="p">,</span>
|
|
<span class="n">prompt_vocab_size</span><span class="o">=</span><span class="n">prompt_vocab_size</span><span class="p">,</span>
|
|
<span class="n">ite</span><span class="o">=</span><span class="n">ite</span><span class="p">,</span>
|
|
<span class="n">sequence_limit_lengths</span><span class="o">=</span><span class="n">sequence_limit_lengths</span><span class="p">,</span>
|
|
<span class="n">output_generation_logits</span><span class="o">=</span><span class="n">output_generation_logits</span><span class="p">,</span>
|
|
<span class="n">stop_words_data</span><span class="o">=</span><span class="n">stop_words_data</span><span class="p">,</span>
|
|
<span class="n">bad_words_data</span><span class="o">=</span><span class="n">bad_words_data</span><span class="p">,</span>
|
|
<span class="n">output_sequence_lengths</span><span class="o">=</span><span class="n">output_sequence_lengths</span><span class="p">,</span>
|
|
<span class="n">return_dict</span><span class="o">=</span><span class="n">return_dict</span><span class="p">,</span>
|
|
<span class="n">encoder_output</span><span class="o">=</span><span class="n">encoder_output</span><span class="p">,</span>
|
|
<span class="n">encoder_input_lengths</span><span class="o">=</span><span class="n">encoder_input_lengths</span><span class="p">,</span>
|
|
<span class="n">stopping_criteria</span><span class="o">=</span><span class="n">stopping_criteria</span><span class="p">,</span>
|
|
<span class="n">logits_processor</span><span class="o">=</span><span class="n">logits_processor</span><span class="p">,</span>
|
|
<span class="n">cross_attention_mask</span><span class="o">=</span><span class="n">cross_attention_mask</span><span class="p">,</span>
|
|
<span class="o">**</span><span class="n">kwargs</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode_regular</span><span class="p">(</span>
|
|
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
|
|
<span class="n">scfg</span><span class="o">=</span><span class="n">scfg</span><span class="p">,</span>
|
|
<span class="n">sequence_lengths</span><span class="o">=</span><span class="n">sequence_lengths</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="o">=</span><span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">host_context_lengths</span><span class="o">=</span><span class="n">host_context_lengths</span><span class="p">,</span>
|
|
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_context_length</span><span class="p">,</span>
|
|
<span class="n">beam_width</span><span class="o">=</span><span class="n">beam_width</span><span class="p">,</span>
|
|
<span class="n">cache_indirections</span><span class="o">=</span><span class="n">cache_indirections</span><span class="p">,</span>
|
|
<span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">,</span>
|
|
<span class="n">hidden_states</span><span class="o">=</span><span class="n">hidden_states</span><span class="p">,</span>
|
|
<span class="n">prompt_embedding_table</span><span class="o">=</span><span class="n">prompt_embedding_table</span><span class="p">,</span>
|
|
<span class="n">tasks</span><span class="o">=</span><span class="n">tasks</span><span class="p">,</span>
|
|
<span class="n">prompt_vocab_size</span><span class="o">=</span><span class="n">prompt_vocab_size</span><span class="p">,</span>
|
|
<span class="n">ite</span><span class="o">=</span><span class="n">ite</span><span class="p">,</span>
|
|
<span class="n">sequence_limit_lengths</span><span class="o">=</span><span class="n">sequence_limit_lengths</span><span class="p">,</span>
|
|
<span class="n">stop_words_data</span><span class="o">=</span><span class="n">stop_words_data</span><span class="p">,</span>
|
|
<span class="n">bad_words_data</span><span class="o">=</span><span class="n">bad_words_data</span><span class="p">,</span>
|
|
<span class="n">output_sequence_lengths</span><span class="o">=</span><span class="n">output_sequence_lengths</span><span class="p">,</span>
|
|
<span class="n">output_generation_logits</span><span class="o">=</span><span class="n">output_generation_logits</span><span class="p">,</span>
|
|
<span class="n">return_dict</span><span class="o">=</span><span class="n">return_dict</span><span class="p">,</span>
|
|
<span class="n">encoder_output</span><span class="o">=</span><span class="n">encoder_output</span><span class="p">,</span>
|
|
<span class="n">encoder_input_lengths</span><span class="o">=</span><span class="n">encoder_input_lengths</span><span class="p">,</span>
|
|
<span class="n">stopping_criteria</span><span class="o">=</span><span class="n">stopping_criteria</span><span class="p">,</span>
|
|
<span class="n">logits_processor</span><span class="o">=</span><span class="n">logits_processor</span><span class="p">,</span>
|
|
<span class="n">cross_attention_mask</span><span class="o">=</span><span class="n">cross_attention_mask</span><span class="p">,</span>
|
|
<span class="o">**</span><span class="n">kwargs</span><span class="p">,</span>
|
|
<span class="p">)</span></div>
|
|
</div>
|
|
|
|
|
|
|
|
<div class="viewcode-block" id="ChatGLMGenerationSession">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.ChatGLMGenerationSession">[docs]</a>
|
|
<span class="k">class</span><span class="w"> </span><span class="nc">ChatGLMGenerationSession</span><span class="p">(</span><span class="n">GenerationSession</span><span class="p">):</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">model_config</span><span class="p">:</span> <span class="n">ModelConfig</span><span class="p">,</span>
|
|
<span class="n">engine_buffer</span><span class="p">,</span>
|
|
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span>
|
|
<span class="n">debug_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">stream</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="p">):</span>
|
|
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
|
|
<span class="n">model_config</span><span class="p">,</span>
|
|
<span class="n">engine_buffer</span><span class="p">,</span>
|
|
<span class="n">mapping</span><span class="p">,</span>
|
|
<span class="n">debug_mode</span><span class="p">,</span>
|
|
<span class="n">debug_tensors_to_save</span><span class="p">,</span>
|
|
<span class="n">cuda_graph_mode</span><span class="p">,</span>
|
|
<span class="n">stream</span><span class="p">,</span>
|
|
<span class="p">)</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span> <span class="o">=</span> <span class="kc">None</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_prepare_context_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span>
|
|
<span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
|
|
|
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'max_context_length'</span><span class="p">)</span>
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
|
|
|
|
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
|
|
<span class="n">input_lengths_acc</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
|
|
<span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">IntTensor</span><span class="p">([</span><span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">cuda</span><span class="p">(),</span> <span class="n">context_lengths</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span>
|
|
<span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="n">i</span><span class="p">]:</span><span class="n">input_lengths_acc</span><span class="p">[</span>
|
|
<span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
|
|
<span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span>
|
|
<span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="mi">2</span>
|
|
<span class="n">position_ids</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="n">input_lengths_acc</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
|
|
|
|
<span class="c1"># specialization for GLM series models</span>
|
|
<span class="k">if</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">"pad_id"</span><span class="p">]</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">50256</span><span class="p">,</span> <span class="mi">50259</span><span class="p">]:</span>
|
|
<span class="k">if</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">"pad_id"</span><span class="p">]</span> <span class="o">==</span> <span class="mi">50256</span><span class="p">:</span> <span class="c1"># glm_2b / glm_10b</span>
|
|
<span class="n">mask_ids</span> <span class="o">=</span> <span class="p">[</span><span class="mi">50260</span><span class="p">,</span> <span class="mi">50264</span><span class="p">,</span> <span class="mi">50263</span><span class="p">]</span>
|
|
<span class="k">else</span><span class="p">:</span> <span class="c1"># glm_10b_chinese / glm_large_chinese</span>
|
|
<span class="n">mask_ids</span> <span class="o">=</span> <span class="p">[</span><span class="mi">50003</span><span class="p">,</span> <span class="mi">50008</span><span class="p">,</span> <span class="mi">50009</span><span class="p">]</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span> <span class="o">=</span> \
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="n">length</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
|
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">"input_ids"</span><span class="p">][</span>
|
|
<span class="mi">0</span><span class="p">:</span><span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">kwargs</span><span class="p">[</span>
|
|
<span class="s2">"input_ids"</span><span class="p">][</span><span class="nb">sum</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="n">i</span><span class="p">]</span>
|
|
<span class="p">):</span><span class="nb">sum</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="n">i</span><span class="p">])</span> <span class="o">+</span>
|
|
<span class="n">length</span><span class="p">]</span>
|
|
<span class="n">mask_index</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">input_ids</span> <span class="o">==</span> <span class="nb">id</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">int</span><span class="p">()</span> <span class="k">for</span> <span class="nb">id</span> <span class="ow">in</span> <span class="n">mask_ids</span>
|
|
<span class="p">]</span>
|
|
<span class="n">tail_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">([</span><span class="n">max_context_length</span><span class="p">])</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
|
|
<span class="n">mask_index</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">tail_index</span><span class="p">)</span>
|
|
<span class="n">mask_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">mask_index</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">min</span><span class="p">()</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">mask_index</span><span class="p">)</span>
|
|
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="nb">sum</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">])</span> <span class="o">-</span>
|
|
<span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">mask_index</span><span class="p">)</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">],</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
<span class="n">position_ids</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">)</span>
|
|
|
|
<span class="c1"># specialization for GLM series models</span>
|
|
<span class="k">if</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">"pad_id"</span><span class="p">]</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">50256</span><span class="p">,</span> <span class="mi">50259</span><span class="p">]:</span>
|
|
<span class="k">if</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">"pad_id"</span><span class="p">]</span> <span class="o">==</span> <span class="mi">50256</span><span class="p">:</span> <span class="c1"># glm_2b / glm_10b</span>
|
|
<span class="n">mask_ids</span> <span class="o">=</span> <span class="p">[</span><span class="mi">50260</span><span class="p">,</span> <span class="mi">50264</span><span class="p">,</span> <span class="mi">50263</span><span class="p">]</span>
|
|
<span class="k">else</span><span class="p">:</span> <span class="c1"># glm_10b_chinese / glm_large_chinese</span>
|
|
<span class="n">mask_ids</span> <span class="o">=</span> <span class="p">[</span><span class="mi">50003</span><span class="p">,</span> <span class="mi">50008</span><span class="p">,</span> <span class="mi">50009</span><span class="p">]</span>
|
|
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span> <span class="o">=</span> \
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="n">length</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
|
<span class="n">input_ids</span> <span class="o">=</span> <span class="n">kwargs</span><span class="p">[</span><span class="s2">"input_ids"</span><span class="p">][</span><span class="n">i</span><span class="p">]</span>
|
|
<span class="n">mask_index</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">input_ids</span> <span class="o">==</span> <span class="nb">id</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">int</span><span class="p">()</span> <span class="k">for</span> <span class="nb">id</span> <span class="ow">in</span> <span class="n">mask_ids</span>
|
|
<span class="p">]</span>
|
|
<span class="n">tail_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">([</span><span class="n">max_context_length</span><span class="p">])</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
|
|
<span class="n">mask_index</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">tail_index</span><span class="p">)</span>
|
|
<span class="n">mask_index</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">mask_index</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">min</span><span class="p">()</span>
|
|
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">mask_index</span><span class="p">)</span>
|
|
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">mask_index</span><span class="p">)</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="n">length</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
|
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">2</span>
|
|
<span class="n">position_ids</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
|
|
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
|
|
|
|
<span class="n">perf_knob_tensor_size</span> <span class="o">=</span> <span class="mi">16</span>
|
|
<span class="n">context_runtime_perf_knobs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">perf_knob_tensor_size</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
|
|
|
|
<span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span>
|
|
<span class="s1">'position_ids'</span><span class="p">:</span> <span class="n">position_ids</span><span class="p">,</span>
|
|
<span class="s1">'last_token_ids'</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">,</span>
|
|
<span class="s1">'host_runtime_perf_knobs'</span><span class="p">:</span> <span class="n">context_runtime_perf_knobs</span>
|
|
<span class="p">}</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
|
|
<span class="n">inputs</span><span class="p">[</span><span class="s1">'attention_mask'</span><span class="p">]</span> <span class="o">=</span> <span class="n">attention_mask</span>
|
|
<span class="k">return</span> <span class="n">inputs</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_prepare_generation_inputs</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">context_lengths</span><span class="p">,</span>
|
|
<span class="n">use_gpt_attention_plugin</span><span class="p">,</span>
|
|
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
|
|
|
<span class="n">step</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'step'</span><span class="p">)</span>
|
|
<span class="n">num_beams</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">'num_beams'</span><span class="p">)</span>
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">context_lengths</span><span class="p">)</span>
|
|
|
|
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">_tile_beam_width_chatglm</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
|
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
|
|
<span class="n">new_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">num_beams</span>
|
|
<span class="n">tile_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">new_shape</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
<span class="n">tile_size</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="n">tile_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>
|
|
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
|
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">tile_size</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
|
|
<span class="n">new_tensor</span> <span class="o">=</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">new_shape</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
|
|
<span class="k">return</span> <span class="n">new_tensor</span>
|
|
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">num_beams</span><span class="p">]</span> <span class="o">-</span> <span class="mi">2</span>
|
|
<span class="n">position_ids</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">step</span> <span class="o">+</span> <span class="mi">2</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">_tile_beam_width_chatglm</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
|
|
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">int</span><span class="p">()</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
|
|
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="c1"># specialization for GLM series models</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="n">position_ids</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">position_ids</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="n">data</span> <span class="o">=</span> <span class="p">[]</span>
|
|
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="c1"># specialization for GLM series models</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="n">data</span><span class="o">.</span><span class="n">append</span><span class="p">([[</span><span class="bp">self</span><span class="o">.</span><span class="n">mask_index_tensor</span><span class="p">[</span><span class="n">i</span><span class="p">]],</span> <span class="p">[</span><span class="n">step</span> <span class="o">+</span> <span class="mi">2</span><span class="p">]])</span>
|
|
<span class="k">else</span><span class="p">:</span>
|
|
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
|
|
<span class="n">data</span><span class="o">.</span><span class="n">append</span><span class="p">([[</span><span class="n">context_lengths</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">num_beams</span><span class="p">]</span> <span class="o">-</span> <span class="mi">2</span><span class="p">],</span>
|
|
<span class="p">[</span><span class="n">step</span> <span class="o">+</span> <span class="mi">2</span><span class="p">]])</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span>
|
|
<span class="n">position_ids</span> <span class="o">=</span> <span class="n">_tile_beam_width</span><span class="p">(</span><span class="n">position_ids</span><span class="p">,</span> <span class="n">num_beams</span><span class="p">)</span>
|
|
|
|
<span class="n">perf_knob_tensor_size</span> <span class="o">=</span> <span class="mi">16</span>
|
|
<span class="n">generation_runtime_perf_knobs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span>
|
|
<span class="n">perf_knob_tensor_size</span><span class="p">,</span>
|
|
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
|
|
|
|
<span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span>
|
|
<span class="s1">'position_ids'</span><span class="p">:</span> <span class="n">position_ids</span><span class="p">,</span>
|
|
<span class="s1">'last_token_ids'</span><span class="p">:</span> <span class="n">last_token_ids</span><span class="p">,</span>
|
|
<span class="s1">'host_runtime_perf_knobs'</span><span class="p">:</span> <span class="n">generation_runtime_perf_knobs</span>
|
|
<span class="p">}</span>
|
|
<span class="k">if</span> <span class="ow">not</span> <span class="n">use_gpt_attention_plugin</span><span class="p">:</span>
|
|
<span class="n">attention_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
|
|
<span class="n">inputs</span><span class="p">[</span><span class="s1">'attention_mask'</span><span class="p">]</span> <span class="o">=</span> <span class="n">attention_mask</span>
|
|
<span class="k">return</span> <span class="n">inputs</span></div>
|
|
|
|
|
|
|
|
<div class="viewcode-block" id="QWenForCausalLMGenerationSession">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.QWenForCausalLMGenerationSession">[docs]</a>
|
|
<span class="k">class</span><span class="w"> </span><span class="nc">QWenForCausalLMGenerationSession</span><span class="p">(</span><span class="n">GenerationSession</span><span class="p">):</span>
|
|
|
|
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">model_config</span><span class="p">:</span> <span class="n">ModelConfig</span><span class="p">,</span>
|
|
<span class="n">engine_buffer</span><span class="p">,</span>
|
|
<span class="n">mapping</span><span class="p">:</span> <span class="n">Mapping</span><span class="p">,</span>
|
|
<span class="n">debug_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">stream</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
|
<span class="n">global_max_input_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span><span class="p">,</span>
|
|
<span class="n">global_max_output_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4096</span><span class="p">,</span>
|
|
<span class="p">):</span>
|
|
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">model_config</span><span class="p">,</span>
|
|
<span class="n">engine_buffer</span><span class="p">,</span>
|
|
<span class="n">mapping</span><span class="p">,</span>
|
|
<span class="n">debug_mode</span><span class="p">,</span>
|
|
<span class="n">debug_tensors_to_save</span><span class="o">=</span><span class="n">debug_tensors_to_save</span><span class="p">,</span>
|
|
<span class="n">cuda_graph_mode</span><span class="o">=</span><span class="n">cuda_graph_mode</span><span class="p">,</span>
|
|
<span class="n">stream</span><span class="o">=</span><span class="n">stream</span><span class="p">)</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">global_max_input_length</span> <span class="o">=</span> <span class="n">global_max_input_length</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">global_max_output_length</span> <span class="o">=</span> <span class="n">global_max_output_length</span>
|
|
|
|
<div class="viewcode-block" id="QWenForCausalLMGenerationSession.generate">
|
|
<a class="viewcode-back" href="../../../python-api/tensorrt_llm.runtime.html#tensorrt_llm.runtime.QWenForCausalLMGenerationSession.generate">[docs]</a>
|
|
<span class="k">def</span><span class="w"> </span><span class="nf">generate</span><span class="p">(</span>
|
|
<span class="bp">self</span><span class="p">,</span>
|
|
<span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">input_lengths</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
|
|
<span class="n">sampling_config</span><span class="p">:</span> <span class="n">SamplingConfig</span><span class="p">,</span>
|
|
<span class="n">max_new_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
|
<span class="n">runtime_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
|
<span class="p">):</span>
|
|
<span class="n">max_input_length</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">input_lengths</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
|
|
<span class="n">max_new_tokens</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">max_new_tokens</span><span class="p">,</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">global_max_output_length</span> <span class="o">-</span> <span class="n">max_input_length</span><span class="p">)</span>
|
|
<span class="c1"># setup batch_size, max_input_length, max_output_len</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">setup</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="n">input_lengths</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
|
<span class="n">max_context_length</span><span class="o">=</span><span class="n">max_input_length</span><span class="p">,</span>
|
|
<span class="n">max_new_tokens</span><span class="o">=</span><span class="n">max_new_tokens</span><span class="p">)</span>
|
|
<span class="n">output_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">input_lengths</span><span class="p">,</span> <span class="n">sampling_config</span><span class="p">)</span>
|
|
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
|
|
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
|
|
<span class="k">if</span> <span class="n">runtime_rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
|
<span class="n">outputs</span> <span class="o">=</span> <span class="n">output_ids</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="p">:]</span>
|
|
<span class="k">return</span> <span class="n">outputs</span></div>
|
|
</div>
|
|
|
|
</pre></div>
|
|
|
|
</article>
|
|
|
|
|
|
|
|
|
|
|
|
<footer class="prev-next-footer d-print-none">
|
|
|
|
<div class="prev-next-area">
|
|
</div>
|
|
</footer>
|
|
|
|
</div>
|
|
|
|
|
|
|
|
<div class="bd-sidebar-secondary"></div>
|
|
|
|
|
|
|
|
|
|
|
|
</div>
|
|
<footer class="bd-footer-content">
|
|
|
|
</footer>
|
|
|
|
</main>
|
|
</div>
|
|
</div>
|
|
|
|
<!-- Scripts loaded after <body> so the DOM is not blocked -->
|
|
<script defer src="../../../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
|
|
<script defer src="../../../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
|
|
|
|
<footer class="bd-footer">
|
|
<div class="bd-footer__inner bd-page-width">
|
|
|
|
<div class="footer-items__start">
|
|
|
|
<div class="footer-item">
|
|
<a class="footer-brand logo" href="https://www.nvidia.com">
|
|
<img src="../../../_static/nvidia-logo-horiz-rgb-1c-blk-for-screen.svg" class="logo__image only-light" alt="NVIDIA"/>
|
|
<img src="../../../_static/nvidia-logo-horiz-rgb-1c-wht-for-screen.svg" class="logo__image only-dark" alt="NVIDIA"/>
|
|
</a></div>
|
|
|
|
<div class="footer-item">
|
|
|
|
<div class="footer-links">
|
|
|
|
|
|
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/privacy-policy/">Privacy Policy</a>
|
|
|
|
|
|
|
|
|
|
|
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/privacy-center/">Manage My Privacy</a>
|
|
|
|
|
|
|
|
|
|
|
<a class="external" href="https://www.nvidia.com/en-us/preferences/start/">Do Not Sell or Share My Data</a>
|
|
|
|
|
|
|
|
|
|
|
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/terms-of-service/">Terms of Service</a>
|
|
|
|
|
|
|
|
|
|
|
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/accessibility/">Accessibility</a>
|
|
|
|
|
|
|
|
|
|
|
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/company-policies/">Corporate Policies</a>
|
|
|
|
|
|
|
|
|
|
|
<a class="external" href="https://www.nvidia.com/en-us/product-security/">Product Security</a>
|
|
|
|
|
|
|
|
|
|
|
<a class="external" href="https://www.nvidia.com/en-us/contact/">Contact</a>
|
|
|
|
|
|
|
|
</div>
|
|
</div>
|
|
|
|
<div class="footer-item">
|
|
|
|
|
|
|
|
|
|
<p class="copyright">
|
|
|
|
Copyright © 2025, NVidia.
|
|
<br/>
|
|
|
|
</p>
|
|
</div>
|
|
|
|
<div class="footer-item">
|
|
<div class="extra_footer">
|
|
|
|
<p>Last updated on June 03, 2025.</p>
|
|
|
|
<p>This page is generated by TensorRT-LLM commit <a href="https://github.com/NVIDIA/TensorRT-LLM/tree/9ae2ce6">9ae2ce6</a>.</p>
|
|
|
|
</div></div>
|
|
|
|
</div>
|
|
|
|
|
|
|
|
</div>
|
|
|
|
</footer>
|
|
</body>
|
|
</html> |