Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00 (8806 lines, 1.1 MiB)
Source code for tensorrt_llm.functional
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import weakref
from collections import OrderedDict
from enum import IntEnum
from functools import partial
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import torch

# isort: off
import tensorrt as trt
# isort: on

from . import graph_rewriting as gw
from ._common import default_net, default_trtnet, precision
from ._utils import (QuantModeWrapper, bf16_array, bool_array,
                     dim_resolve_negative, dim_to_trt_axes, dims_array,
                     fp16_array, fp32_array, get_sm_version, int32_array,
                     int64_array, np_dtype_to_trt, str_dtype_to_trt,
                     trt_dtype_to_np, trt_dtype_to_str)
from .network import PluginInfo, set_np_weight, set_plugin_info
from .plugin import TRT_LLM_PLUGIN_NAMESPACE, current_all_reduce_helper
from .quantization import QuantMode

<div class="viewcode-block" id="DimRange">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.DimRange">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">DimRange</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> One DimRange object stores the ranges of all the dimensions of one tensor in one optimization profile.</span>
|
||
<span class="sd"> For example, tensor has 2 dimensions. Then the data members are:</span>
|
||
<span class="sd"> self.min = [dim 0 min, dim 1 min]</span>
|
||
<span class="sd"> self.opt = [dim 0 opt, dim 1 opt]</span>
|
||
<span class="sd"> self.max = [dim 0 max, dim 1 max]</span>
|
||
<span class="sd"> For static dimension, it has min==opt==max, thus the shape param in the ctor can be an integer</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]],</span>
|
||
<span class="n">names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> shape: a list with length N, each element is an integer or a 3-elements tuple/list of int,</span>
|
||
<span class="sd"> where N is the number of dimensions for a tensor.</span>
|
||
<span class="sd"> When one element is an integer, it means that dimension is static.</span>
|
||
<span class="sd"> Otherwise, when one element is a tuple/list, it means the dimension is dynamic.</span>
|
||
<span class="sd"> The 3 elements in one tuple/list is ordered by (min, opt, max), and this function asserts</span>
|
||
<span class="sd"> 0 <= min <= opt <= max.</span>
|
||
|
||
<span class="sd"> Example, for a 3 rank tensor, with 1st dimension being static and has value 16, and second dimension being dynamic with</span>
|
||
<span class="sd"> min/opt/max values being 1/8/32, and 3rd dimension being static and has value 8.</span>
|
||
<span class="sd"> The shape parameter could be:</span>
|
||
<span class="sd"> [16, (1, 8, 32), 8]</span>
|
||
<span class="sd"> It has same semantics of</span>
|
||
<span class="sd"> [(16, 16, 16), (1, 8, 32), (8, 8, 8)]</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">min</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">opt</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">dimension_names</span> <span class="o">=</span> <span class="n">names</span>
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">names</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span>
|
||
<span class="n">shape</span>
|
||
<span class="p">),</span> <span class="s2">"Expecting shape list and name list must have same length, got {shape=}, {name=}"</span>
|
||
<span class="k">for</span> <span class="n">dim</span> <span class="ow">in</span> <span class="n">shape</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span> <span class="ow">and</span> <span class="mi">0</span> <span class="o"><=</span> <span class="n">dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o"><=</span> <span class="n">dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o"><=</span> <span class="n">dim</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> \
|
||
<span class="s2">"Each dimension must specify a 3-elements tuple or list in the order of (min,opt,max), got {dim=}"</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">min</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">opt</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">min</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">opt</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">AttributeError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s1">'Dimension should be [min, opt, max] (dynamic shape) or int (specific value). Got </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span><span class="si">}</span><span class="s1">'</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__eq__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">__value</span><span class="p">:</span> <span class="nb">object</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">__value</span><span class="p">,</span> <span class="n">DimRange</span><span class="p">)</span> <span class="ow">and</span> \
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">dimension_names</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">dimension_names</span> <span class="ow">and</span> \
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">min</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">min</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">opt</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">opt</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">max</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">max</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__str__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">str</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">dimension_names</span><span class="si">=}</span><span class="s2"> </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">min</span><span class="si">=}</span><span class="s2">, </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">opt</span><span class="si">=}</span><span class="s2">, </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">max</span><span class="si">=}</span><span class="s2">)"</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__hash__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="nb">hash</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="p">))</span></div>


class Tensor(object):
    '''
    The class to represent dense tensors.

    A dense tensor is named, has a shape and contains typed elements. Each
    dimension of a tensor can either be static or dynamic. Static dimensions
    are known at engine compilation by TensorRT. Dynamic dimensions can take
    values determined at runtime. The tensor can be located on the host (CPU)
    or the device (GPU).
    '''

    def __init__(self,
                 name=None,
                 dtype=None,
                 shape=None,
                 dim_range=None,
                 is_network_input=True,
                 location=trt.TensorLocation.DEVICE,
                 network=None,
                 trt_tensor=None):
        '''
        Parameters:
            name : str
                The name of the tensor.

            dtype : tensorrt.DataType
                The type of the elements of the tensor. See the TensorRT
                documentation for the list of supported data types.

            shape : tensorrt.Dims
                The dimensions of the tensor. In TensorRT-LLM, tensors can have
                static or dynamic dimensions (it is possible to mix static and
                dynamic dimensions). A static dimension is known when the
                TensorRT engine is built. A dynamic dimension can be set when
                the engine is executed. Use -1 for dynamic dimensions.

            dim_range : OrderedDict
                An ordered dictionary (the positions of the elements matter)
                that associates a name and a range of values with each
                dimension. For a static dimension, the range must be limited
                to a single value. For a dynamic dimension, the range is
                defined by three values [min, opt, max] where min and max are,
                respectively, the smallest and largest possible values of that
                dimension. The opt value is used by TensorRT to optimize the
                engine for the most common case.

                Assuming there are N optimization profiles, each dim_range
                item is ordered as:
                    (dynamic dimension name : [profile 0 (min, opt, max), profile 1 (min, opt, max), ... profile N (min, opt, max)])
                or, when the dimension is static (think of it as
                min==opt==max):
                    (static dimension name : [profile 0 value, profile 1 value, ... profile N value])
                For a static dimension the profile 0..N values must all be the
                same (TODO: can it be simplified to be only 1 value?). The
                length of each value list equals the number of optimization
                profiles.

            is_network_input : bool
                A boolean indicating if that tensor is an input of the
                network. Inputs must be provided by the user to run the
                engine.

            location : tensorrt.TensorLocation
                A flag to indicate where the tensor will be located. It can be
                on the host (CPU) or the device (GPU).

            network: Network
                A parent Network instance, that helps to find the users of
                this tensor.

            trt_tensor: trt.ITensor
                Construct with the ITensor instance directly, and no shape
                profiles are required.
        '''

        # Layout of self.profiles
        # Opt profile 0: dim 0 (min, opt, max), dim 1 (min, opt, max) ... dim M
        # Opt profile 1: dim 0 (min, opt, max), dim 1 (min, opt, max) ... dim M
        # ...
        # Opt profile N: dim 0 ... dim M

        # So the conversion from the dim_range arg to self.profiles involves a
        # layout transpose: the dim_range arg is {M dimensions x N profiles},
        # while the self.profiles layout is {N profiles x M dimensions}.
        if isinstance(dtype, str):
            dtype = str_dtype_to_trt(dtype)

        self.profiles = []

        self.is_tensor_wrapper = False  # specially for the graph rewriter

        # work as a wrapper for a trt.ITensor, this is used specially in the graph rewriter
        if trt_tensor is not None:
            self.is_tensor_wrapper = True
            assert network is not None
            self.trt_tensor = trt_tensor
            self._network = weakref.ref(network)
            assert not is_network_input, "is_network_input should be False when trt_tensor is not None"
            return

        # Be cautious here: the weakref is critical to avoid circular
        # references between Network and Tensor. Using a strong reference
        # would likely cause a significant peak-memory increase, since
        # Network objects hold the weights data.
        self._network = weakref.ref(default_net())
        self.is_network_input = is_network_input
        if is_network_input:
            if dim_range is not None:
                assert isinstance(dim_range, OrderedDict)
                assert len(
                    dim_range
                ) >= 1, f"Each input tensor shall have at least one dimension, tensor '{name}' found {dim_range=}"

                found_profiles = [
                    len(ranges) for _, ranges in dim_range.items()
                ]
                assert all(
                    [x == found_profiles[0] for x in found_profiles]
                ), f"Expecting all the dimensions in the dim_range to have the same number of profiles, tensor '{name}' got {dim_range=}"

                num_opt_profile = len(list(dim_range.items())[0][1])
                assert num_opt_profile >= 1
                for i in range(num_opt_profile):
                    range_shape = []
                    dimension_names = []
                    for dim, ranges in dim_range.items():
                        assert isinstance(ranges, (list, tuple))
                        range_shape.append(ranges[i])
                        dimension_names.append(dim)
                    self.profiles.append(DimRange(range_shape, dimension_names))

            default_net()._add_input(self, name, dtype, shape, dim_range)
            self.name = name
<span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">shape</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">location</span> <span class="o">=</span> <span class="n">location</span>
|
||
|
||
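    # Illustrative sketch (assumed usage, not part of the original source): a
    # network input with two optimization profiles. Each dim_range value lists
    # one [min, opt, max] triple per profile, and every dimension must provide
    # the same number of profiles.
    #
    #   input_ids = Tensor(name='input_ids',
    #                      dtype=str_dtype_to_trt('int32'),
    #                      shape=(-1, -1),
    #                      dim_range=OrderedDict([
    #                          ('batch_size', [[1, 8, 64], [1, 16, 128]]),
    #                          ('input_len', [[1, 128, 1024], [1, 256, 2048]]),
    #                      ]),
    #                      is_network_input=True)
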
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">network</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_network</span><span class="p">()</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">name</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> The name of the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span>
|
||
|
||
<span class="nd">@name</span><span class="o">.</span><span class="n">setter</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">name</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Set the name of the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">name</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> The type of the elements in the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">dtype</span>
|
||
|
||
<span class="nd">@dtype</span><span class="o">.</span><span class="n">setter</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Set the type of the elements in the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> The shape of the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
|
||
|
||
<span class="nd">@shape</span><span class="o">.</span><span class="n">setter</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Set the shape of the tensor. See __init__.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">shape</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">shape</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">location</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> The physical location of the tensor (on the host or the device).</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">location</span>
|
||
|
||
<span class="nd">@location</span><span class="o">.</span><span class="n">setter</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">location</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">location</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Set the physical location of the tensor (on the host or the device). See __init__.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">location</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">location</span> <span class="o">=</span> <span class="n">location</span>
|
||
|
||
<div class="viewcode-block" id="Tensor.mark_output">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.mark_output">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">mark_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
||
<span class="n">name</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Mark a tensor as a network output.</span>
|
||
|
||
<span class="sd"> When a tensor is marked as an output, its content can be obtained after</span>
|
||
<span class="sd"> the execution of the TensorRT engine. The user is responsible for</span>
|
||
<span class="sd"> allocating buffers to store the output tensors when preparing the</span>
|
||
<span class="sd"> execution of the TensorRT engine.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">name</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
|
||
<span class="n">dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
|
||
<span class="k">assert</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">)</span>
|
||
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">_mark_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">dtype</span><span class="p">)</span></div>
|
||
|
||
|
||
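    # Illustrative sketch (assumed usage, not part of the original source):
    # mark a tensor as an engine output, optionally overriding its dtype; a
    # string dtype is converted via str_dtype_to_trt before the tensor is
    # registered on the active default network.
    #
    #   logits.mark_output('logits', 'float32')
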
<span class="k">def</span><span class="w"> </span><span class="fm">__add__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.add.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">add</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__radd__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.add.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">add</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__sub__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.sub.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">sub</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__rsub__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.sub.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">sub</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__mul__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.mul.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">mul</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__rmul__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.mul.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">mul</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__truediv__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.div.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">div</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__floordiv__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.floordiv.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">floordiv</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__mod__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.floordiv.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">modulo</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__lt__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.lt.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">lt</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__gt__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.gt.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">gt</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__eq__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.eq.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_tensor_wrapper</span><span class="p">:</span>
|
||
<span class="c1"># for graph rewriter</span>
|
||
<span class="k">return</span> <span class="nb">hash</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">==</span> <span class="nb">hash</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="c1"># for creating the network</span>
|
||
<span class="k">return</span> <span class="n">eq</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__ge__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Maps to functional.gt or functional.eq.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">op_or</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="fm">__gt__</span><span class="p">(</span><span class="n">b</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="fm">__eq__</span><span class="p">(</span><span class="n">b</span><span class="p">))</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__le__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Maps to functional.lt or functional.eq.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">op_or</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="fm">__lt__</span><span class="p">(</span><span class="n">b</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="fm">__eq__</span><span class="p">(</span><span class="n">b</span><span class="p">))</span>
|
||
|
||
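    # Illustrative sketch (assumed usage, not part of the original source):
    # the overloads above let graph construction read like plain arithmetic;
    # e.g. `a >= b` lowers to op_or(gt(a, b), eq(a, b)) and adds the
    # corresponding elementwise layers to the active network.
    #
    #   mask = (scores >= threshold) * scale
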
<div class="viewcode-block" id="Tensor.view">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.view">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">view</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.view.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">view</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.flatten">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.flatten">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">flatten</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">start_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">end_dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.flatten.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">flatten</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">start_dim</span><span class="p">,</span> <span class="n">end_dim</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.permute">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.permute">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">permute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.permute.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">permute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.transpose">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.transpose">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">transpose</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim0</span><span class="p">,</span> <span class="n">dim1</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.transpose.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">transpose</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim0</span><span class="p">,</span> <span class="n">dim1</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.mean">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.mean">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">mean</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.mean.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">mean</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.max">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.max">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">max</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.max.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="nb">max</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.abs">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.abs">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">abs</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.abs.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="nb">abs</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.sqrt">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.sqrt">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.sqrt.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.squeeze">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.squeeze">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">squeeze</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.squeeze.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">squeeze</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.unsqueeze">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.unsqueeze">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">unsqueeze</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.squeeze.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">unsqueeze</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.log">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.log">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">log</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.log.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">log</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.cast">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.cast">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">cast</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.cast.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">cast</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.size">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.size">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">size</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Returns the shape of the tensor if the dim parameter is None.</span>
|
||
<span class="sd"> Otherwise, returns a size of the dimension indicated by dim. The</span>
|
||
<span class="sd"> behavior is undefined if dim is negative or exceeds the rank of the</span>
|
||
<span class="sd"> tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span>
|
||
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span></div>
|
||
|
||
|
||
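    # Illustrative sketch (assumed usage, not part of the original source):
    # size() returns the whole (possibly dynamic) shape, while size(i) returns
    # a single extent, where -1 denotes a dynamic dimension.
    #
    #   dims = x.size()       # e.g. (-1, -1, 4096)
    #   hidden = x.size(2)    # e.g. 4096
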
<div class="viewcode-block" id="Tensor.rank">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.rank">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">rank</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Returns the rank (i.e. the number of dimensions) of the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.ndim">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.ndim">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">ndim</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Returns the rank (i.e. the number of dimensions) of the tensor.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.split">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.split">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">split</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">split_size_or_sections</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.split.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">split_size_or_sections</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.select">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.select">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">select</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.select.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">select</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">index</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.unbind">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.unbind">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">unbind</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.unbind.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">unbind</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.repeat">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.repeat">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">repeat</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sizes</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> See functional.repeat</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">repeat</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sizes</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.is_dynamic">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.is_dynamic">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">is_dynamic</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> If the argument 'dim' is None, that function returns a boolean that</span>
|
||
<span class="sd"> indicates if the tensor contains a dynamic dimension (True) or not</span>
|
||
<span class="sd"> (False). In that case, the first dimension is excluded (as it usually</span>
|
||
<span class="sd"> corresponds to the batch size). If the argument is an integer, that</span>
|
||
<span class="sd"> functions returns a boolean that indicates if the dimension 'dim' is</span>
|
||
<span class="sd"> dynamic (True) or not (False).</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span>
|
||
|
||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">s</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="kc">True</span>
|
||
|
||
<span class="k">return</span> <span class="kc">False</span></div>
|
||
|
||
|
||
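    # Illustrative sketch (assumed usage, not part of the original source):
    #
    #   x.is_dynamic(0)   # True if the batch dimension is -1
    #   x.is_dynamic()    # True if any non-batch dimension is -1
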
<span class="c1"># graph writer related functions</span>
|
||
|
||
<div class="viewcode-block" id="Tensor.get_parent">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.get_parent">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">get_parent</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">''' Get the layer that produces this tensor. '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">get_tensor_parent</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.get_users">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.get_users">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">get_users</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">''' Get the layers that use this tensor as an input. '''</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">get_tensor_users</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Tensor.replace_all_uses_with">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.replace_all_uses_with">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">replace_all_uses_with</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_tensor</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Replace all uses of this tensor as an input to consumer layers</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">is_graph_altered</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="n">users</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_users</span><span class="p">()</span>
|
||
<span class="k">for</span> <span class="n">user</span> <span class="ow">in</span> <span class="n">users</span><span class="p">:</span>
|
||
<span class="n">inputs_changed</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">user</span><span class="o">.</span><span class="n">num_inputs</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">user</span><span class="o">.</span><span class="n">get_inputs</span><span class="p">(</span><span class="n">i</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="ow">is</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">:</span>
|
||
<span class="n">inputs_changed</span> <span class="o">+=</span> <span class="mi">1</span>
|
||
<span class="n">user</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">inputs_changed</span> <span class="o">>=</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"Tensor not found in layer inputs"</span>
|
||
|
||
<span class="c1"># update the FLayerMetadata as well</span>
|
||
<span class="n">flayer</span> <span class="o">=</span> <span class="n">gw</span><span class="o">.</span><span class="n">FLayerInfoMemo</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">user</span><span class="o">.</span><span class="n">name</span><span class="p">)</span>
|
||
<span class="n">flayer</span> <span class="ow">and</span> <span class="n">flayer</span><span class="o">.</span><span class="n">replace_input_with</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_tensor</span><span class="p">)</span></div>
|
||
|
||
|
||
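    # Illustrative sketch (assumed usage, not part of the original source): a
    # graph-rewrite pass can redirect every consumer of one tensor to another,
    # e.g. after fusing a subgraph into a plugin:
    #
    #   old_output.replace_all_uses_with(plugin_output)
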
<div class="viewcode-block" id="Tensor.is_trt_wrapper">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.is_trt_wrapper">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">is_trt_wrapper</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Check if there is a trt.ITensor member inside, which is required for</span>
|
||
<span class="sd"> graph rewriter. In order to differentiate usages, it may be necessary</span>
|
||
<span class="sd"> to have an inheritance hierarchy.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s1">'trt_tensor'</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="kc">True</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="kc">False</span></div>
|
||
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__hash__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_trt_wrapper</span><span class="p">():</span>
|
||
<span class="k">return</span> <span class="nb">id</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="nb">id</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="sa">f</span><span class="s2">"TensorRT-LLM Tensor: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="si">=}</span><span class="s2"> </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="si">=}</span><span class="s2"> </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="si">=}</span><span class="s2">"</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__xor__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Maps to functional.gt or functional.eq.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"self.shape: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">, b.shape: </span><span class="si">{</span><span class="n">b</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
|
||
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"a.shape: </span><span class="si">{</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">, b.shape: </span><span class="si">{</span><span class="n">b</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">op_xor</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_create_tensor</span><span class="p">(</span><span class="n">trt_tensor</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ITensor</span><span class="p">,</span> <span class="n">producer</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ILayer</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> A helper function to create a TensorRT-LLM Tensor object that encapsulates</span>
|
||
<span class="sd"> the connection between the TensorRT tensor (trt.ITensor) and the layer</span>
|
||
<span class="sd"> (trt.ILayer) that produces it.</span>
|
||
|
||
<span class="sd"> That function is expected to be used as:</span>
|
||
|
||
<span class="sd"> # Insert a new layer in the network using the TensorRT API:</span>
|
||
<span class="sd"> layer = default_trtnet().add_<some_layer>(...)</span>
|
||
<span class="sd"> # Extract the first output of that layer and connect it to the layer.</span>
|
||
<span class="sd"> return _create_tensor(layer.get_output(0), layer)</span>
|
||
|
||
<span class="sd"> That function also sets the precision of the layer/producer to the default</span>
|
||
<span class="sd"> precision of the network.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> trt_tensor : trt.ITensor</span>
|
||
<span class="sd"> The TensorRT tensor to connect to its producer (the layer).</span>
|
||
|
||
<span class="sd"> producer : trt.ILayer</span>
|
||
<span class="sd"> The producer.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The TensorRT-LLM tensor (functional.Tensor) that encapsulates the</span>
|
||
<span class="sd"> TensorRT tensor and the layer that produces it. The former is</span>
|
||
<span class="sd"> accessible through the attribute 'trt_tensor' and the latter using the</span>
|
||
<span class="sd"> attribute 'producer'.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="n">trt_tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="c1"># Set the layer name since this is the only</span>
|
||
<span class="c1"># centralized location to pass the name from</span>
|
||
<span class="c1"># module space to the TRT IR</span>
|
||
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">_set_layer_name</span><span class="p">(</span><span class="n">producer</span><span class="p">)</span>
|
||
|
||
<span class="k">assert</span> <span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="o">.</span><span class="fm">__len__</span><span class="p">(</span>
|
||
<span class="p">)</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"tensor </span><span class="si">{</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span><span class="si">}</span><span class="s2"> has an invalid shape"</span>
|
||
<span class="n">tensor</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
|
||
<span class="n">shape</span><span class="o">=</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
|
||
<span class="n">is_network_input</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||
<span class="n">tensor</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="o">=</span> <span class="n">trt_tensor</span>
|
||
<span class="n">tensor</span><span class="o">.</span><span class="n">producer</span> <span class="o">=</span> <span class="n">producer</span>
|
||
|
||
<span class="c1"># tb.print_stack(limit=10) # FOR DEBUGGING: filter producer.name if needed</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">dtype</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">SHAPE</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">GATHER</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONCATENATION</span>
|
||
<span class="p">]:</span>
|
||
<span class="n">producer</span><span class="o">.</span><span class="n">precision</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">dtype</span>
|
||
<span class="k">assert</span> <span class="n">tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="k">if</span> <span class="n">gw</span><span class="o">.</span><span class="n">FLayerInfoMemo</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">cur_flayer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">gw</span><span class="o">.</span><span class="n">FLayerInfoMemo</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">cur_flayer</span><span class="o">.</span><span class="n">layer_name</span> <span class="o">=</span> <span class="n">producer</span><span class="o">.</span><span class="n">name</span>
|
||
|
||
<span class="k">return</span> <span class="n">tensor</span>
|
||
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plugin_creator</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IPluginCreator</span><span class="p">,</span>
|
||
<span class="n">plugin_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">pfc</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plugin_info</span> <span class="o">=</span> <span class="n">PluginInfo</span><span class="p">(</span><span class="n">plugin_creator</span><span class="p">,</span> <span class="n">plugin_name</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">set_plugin_info</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">trt_network</span><span class="p">,</span> <span class="n">layer</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">plugin_info</span><span class="p">)</span>
|
||
|
||
|
||
<div class="viewcode-block" id="RotaryScalingType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RotaryScalingType">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">RotaryScalingType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">none</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">linear</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">dynamic</span> <span class="o">=</span> <span class="mi">2</span>
|
||
<span class="n">longrope</span> <span class="o">=</span> <span class="mi">3</span>
|
||
<span class="n">llama3</span> <span class="o">=</span> <span class="mi">4</span>
|
||
<span class="n">yarn</span> <span class="o">=</span> <span class="mi">5</span>
|
||
<span class="n">mrope</span> <span class="o">=</span> <span class="mi">6</span>
|
||
|
||
<div class="viewcode-block" id="RotaryScalingType.from_string">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RotaryScalingType.from_string">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_string</span><span class="p">(</span><span class="n">s</span><span class="p">):</span>
|
||
<span class="k">try</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">RotaryScalingType</span><span class="p">[</span><span class="n">s</span><span class="p">]</span>
|
||
<span class="k">except</span> <span class="ne">KeyError</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Unsupported rotary scaling type: </span><span class="si">{</span><span class="n">s</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
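# Usage sketch (added for illustration; not part of the original source):
# `from_string` indexes the enum by member name, so the argument must match a
# member exactly.
#
#     RotaryScalingType.from_string('dynamic')    # -> RotaryScalingType.dynamic
#     RotaryScalingType.from_string('unknown')    # -> ValueError
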
<div class="viewcode-block" id="PositionEmbeddingType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">PositionEmbeddingType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">learned_absolute</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">rope_gptj</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">rope_gpt_neox</span> <span class="o">=</span> <span class="mi">2</span>
|
||
<span class="n">long_rope</span> <span class="o">=</span> <span class="mi">3</span>
|
||
<span class="n">alibi</span> <span class="o">=</span> <span class="mi">4</span>
|
||
<span class="n">alibi_with_scale</span> <span class="o">=</span> <span class="mi">5</span>
|
||
<span class="n">relative</span> <span class="o">=</span> <span class="mi">6</span>
|
||
<span class="n">chatglm</span> <span class="o">=</span> <span class="mi">7</span>
|
||
<span class="n">yarn</span> <span class="o">=</span> <span class="mi">8</span>
|
||
<span class="n">mrope</span> <span class="o">=</span> <span class="mi">9</span>
|
||
<span class="n">deferred</span> <span class="o">=</span> <span class="mi">10</span> <span class="c1"># Apply customized positional embedding by using an external embedder. K will be cached before embedding.</span>
|
||
|
||
<div class="viewcode-block" id="PositionEmbeddingType.is_rope">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.is_rope">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">is_rope</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span> <span class="ow">in</span> <span class="p">[</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">rope_gptj</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">long_rope</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mrope</span>
|
||
<span class="p">]</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="PositionEmbeddingType.is_mrope">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.is_mrope">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">is_mrope</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span> <span class="ow">in</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">mrope</span><span class="p">]</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="PositionEmbeddingType.is_alibi">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.is_alibi">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">is_alibi</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span> <span class="ow">in</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">alibi</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">alibi_with_scale</span><span class="p">]</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="PositionEmbeddingType.is_deferred">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.is_deferred">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">is_deferred</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span> <span class="ow">in</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">deferred</span><span class="p">]</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="PositionEmbeddingType.choices">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.choices">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">choices</span><span class="p">()</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span>
|
||
<span class="k">return</span> <span class="p">[</span><span class="n">embedding</span><span class="o">.</span><span class="n">name</span> <span class="k">for</span> <span class="n">embedding</span> <span class="ow">in</span> <span class="n">PositionEmbeddingType</span><span class="p">]</span></div>
|
||
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__str__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span>
|
||
|
||
<div class="viewcode-block" id="PositionEmbeddingType.from_string">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.from_string">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_string</span><span class="p">(</span><span class="n">s</span><span class="p">):</span>
|
||
<span class="k">try</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">PositionEmbeddingType</span><span class="p">[</span><span class="n">s</span><span class="p">]</span>
|
||
<span class="k">except</span> <span class="ne">KeyError</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Unsupported position embedding type: </span><span class="si">{</span><span class="n">s</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
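# Usage sketch (illustrative only): the predicate helpers make dispatch on the
# embedding family explicit, and `choices()` yields the member names, e.g. for
# a CLI flag.
#
#     pe = PositionEmbeddingType.from_string('rope_gpt_neox')
#     pe.is_rope()                      # True
#     pe.is_alibi()                     # False
#     PositionEmbeddingType.choices()   # ['learned_absolute', 'rope_gptj', ...]
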
<div class="viewcode-block" id="AttentionMaskType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AttentionMaskType">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">AttentionMaskType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">padding</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">causal</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">sliding_window_causal</span> <span class="o">=</span> <span class="mi">2</span>
|
||
<span class="n">bidirectional</span> <span class="o">=</span> <span class="mi">3</span>
|
||
<span class="n">bidirectionalglm</span> <span class="o">=</span> <span class="mi">4</span> <span class="c1"># TODO: merge this mask into bidirectional</span>
|
||
<span class="n">blocksparse</span> <span class="o">=</span> <span class="mi">5</span>
|
||
<span class="n">custom_mask</span> <span class="o">=</span> <span class="mi">6</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="LayerNormType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.LayerNormType">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">LayerNormType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">LayerNorm</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">RmsNorm</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">GroupNorm</span> <span class="o">=</span> <span class="mi">2</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="LayerNormPositionType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.LayerNormPositionType">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">LayerNormPositionType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">pre_layernorm</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">post_layernorm</span> <span class="o">=</span> <span class="mi">1</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="MLPType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.MLPType">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">MLPType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">MLP</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">GatedMLP</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">FusedGatedMLP</span> <span class="o">=</span> <span class="mi">2</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="SliceInputType">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.SliceInputType">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">SliceInputType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
|
||
<span class="n">data</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">start</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">size</span> <span class="o">=</span> <span class="mi">2</span>
|
||
<span class="n">stride</span> <span class="o">=</span> <span class="mi">3</span>
|
||
<span class="n">fill_value</span> <span class="o">=</span> <span class="mi">4</span>
|
||
<span class="n">axes</span> <span class="o">=</span> <span class="mi">5</span></div>
|
||
|
||
|
||
|
||
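# Note (illustration; assumes these values mirror the optional input slots of
# TensorRT's ISliceLayer): the enum lets dynamic slice parameters be wired by
# a named index rather than a bare literal, e.g.
#
#     layer.set_input(SliceInputType.start, start_tensor.trt_tensor)
#     layer.set_input(SliceInputType.size, size_tensor.trt_tensor)
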
<div class="viewcode-block" id="activation">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.activation">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">activation</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">act_type</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an activation function.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> act_type : trt.ActivationType</span>
|
||
<span class="sd"> The type of the activation (RELU, TANH, SIGMOID, ...).</span>
|
||
|
||
<span class="sd"> The following closures are defined in functional.*:</span>
|
||
|
||
<span class="sd"> relu for op=trt.ActivationType.RELU</span>
|
||
<span class="sd"> tanh for op=trt.ActivationType.TANH</span>
|
||
<span class="sd"> sigmoid for op=trt.ActivationType.SIGMOID</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_activation</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">act_type</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
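# Usage sketch (illustrative; assumes an active tensorrt_llm Network scope so
# that `default_trtnet()` resolves):
#
#     y = activation(x, trt.ActivationType.RELU)   # same graph as `relu(x)`
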
<div class="viewcode-block" id="int_clip">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.int_clip">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">int_clip</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">lower</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">upper</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">lower</span> <span class="o"><=</span> <span class="n">upper</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"Lower bound must be less than or equal to upper bound i.e. </span><span class="si">{</span><span class="n">lower</span><span class="si">}</span><span class="s2"> <= </span><span class="si">{</span><span class="n">upper</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="n">res</span> <span class="o">=</span> <span class="n">minimum</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">upper</span><span class="p">)</span>
|
||
<span class="n">res</span> <span class="o">=</span> <span class="n">maximum</span><span class="p">(</span><span class="n">res</span><span class="p">,</span> <span class="n">lower</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">res</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="clip">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.clip">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">clip</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a CLIP operation that sets the range to [alpha, beta].</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> alpha : float</span>
|
||
<span class="sd"> The lower bound of the CLIP function.</span>
|
||
|
||
<span class="sd"> beta : float</span>
|
||
<span class="sd"> The upper bound of the CLIP function.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_activation</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">CLIP</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">alpha</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">beta</span> <span class="o">=</span> <span class="n">beta</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="n">relu</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">activation</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">RELU</span><span class="p">)</span>
|
||
<span class="n">tanh</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">activation</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">TANH</span><span class="p">)</span>
|
||
<span class="n">sigmoid</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">activation</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">SIGMOID</span><span class="p">)</span>
|
||
|
||
|
||
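# Usage sketch (illustrative): the closures above bind `act_type`, so they are
# called with the input tensor only; `clip` bounds values to [alpha, beta].
#
#     y = tanh(x)
#     z = clip(x, alpha=0.0, beta=6.0)   # a ReLU6-style clamp
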
<div class="viewcode-block" id="silu">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.silu">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">silu</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a SiLU (`x * sigmoid(x)`) operation.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="nb">input</span> <span class="o">*</span> <span class="n">sigmoid</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="swiglu">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.swiglu">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">swiglu</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a SwiGLU (`x * SiLU(gate)`) operation.</span>
|
||
|
||
<span class="sd"> That function takes a tensor, splits it into two halves along the last</span>
|
||
<span class="sd"> dimension, applies SiLU to the second half and multiply the results. The</span>
|
||
<span class="sd"> behavior is undefined if the last dimension is not even.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">x</span><span class="p">,</span> <span class="n">gate</span> <span class="o">=</span> <span class="n">chunk</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">silu</span><span class="p">(</span><span class="n">gate</span><span class="p">)</span> <span class="o">*</span> <span class="n">x</span></div>
|
||
|
||
|
||
|
||
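# Shape sketch (illustrative): `swiglu` halves the last dimension, e.g. for an
# input `h` of shape [batch, seq, 2 * d] the output has shape [batch, seq, d]:
#
#     x, gate = chunk(h, 2, dim=-1)   # what swiglu does internally
#     out = silu(gate) * x            # == swiglu(h)
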
<div class="viewcode-block" id="squared_relu">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.squared_relu">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">squared_relu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a Squared ReLU operation.</span>
|
||
|
||
<span class="sd"> This function applies ReLU and squares the output.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the activation function is applied.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the activation layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="nb">pow</span><span class="p">(</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="mf">2.0</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="cast">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.cast">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">cast</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dtype</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a cast operation.</span>
|
||
|
||
<span class="sd"> For an input tensor of type INT8, this function sets the dynamic range of</span>
|
||
<span class="sd"> the input to [-127, 127] for automatic dequantization. For a cast into</span>
|
||
<span class="sd"> INT8, that function sets the dynamic range of the output to [-127, 127] for</span>
|
||
<span class="sd"> automatic quantization.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the cast is applied.</span>
|
||
|
||
<span class="sd"> dtype : str or trt.DataType</span>
|
||
<span class="sd"> The data type of the output tensor after the cast. When 'dtype' is</span>
|
||
<span class="sd"> provided as a string, it must be a name amongst the valid names.</span>
|
||
<span class="sd"> See _str_to_trt_dtype_dict in _utils.py for a list of supported</span>
|
||
<span class="sd"> types and type names.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the inserted layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
|
||
<span class="n">cvt_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">):</span>
|
||
<span class="n">cvt_dtype</span> <span class="o">=</span> <span class="n">dtype</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">"</span><span class="si">%s</span><span class="s2"> is not supported"</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">dtype</span><span class="p">))</span>
|
||
|
||
<span class="k">if</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">cvt_dtype</span><span class="p">:</span>
|
||
<span class="c1"># If input type and cast dtype are the same, do nothing</span>
|
||
<span class="k">return</span> <span class="nb">input</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_cast</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">cvt_dtype</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_output_type</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">cvt_dtype</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="s1">'int8'</span><span class="p">):</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">get_input</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">cvt_dtype</span> <span class="o">==</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="s1">'int8'</span><span class="p">):</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">output</span></div>
|
||
|
||
|
||
|
||
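# Usage sketch (illustrative): both spellings of the target type are accepted;
# casting to the tensor's own dtype is a no-op that returns the input.
#
#     y = cast(x, 'float16')
#     y = cast(x, trt.DataType.HALF)   # equivalent
#     y = cast(x, x.dtype)             # returns x unchanged
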
<div class="viewcode-block" id="flip">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.flip">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">flip</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Reverses the order of an n-D tensor along given axis in dims.</span>
|
||
|
||
<span class="sd"> That flip operation maps to a TensorRT ISliceLayer. For the dimensions</span>
|
||
<span class="sd"> listed in dims it copies the elements from the last one to the first one</span>
|
||
<span class="sd"> (from (N-1) down to 0 with a step of -1). For the dimensions not in 'dims',</span>
|
||
<span class="sd"> it copies the elements from the first one to the last one (from 0 to N-1</span>
|
||
<span class="sd"> with a step of 1).</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which the cast is applied.</span>
|
||
|
||
<span class="sd"> dims : list or tuple</span>
|
||
<span class="sd"> The axes to flip. Negative indices are supported.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the inserted layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
|
||
<span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">dims</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="o">-</span><span class="n">ndim</span> <span class="o"><=</span> <span class="n">value</span> <span class="o"><</span> <span class="n">ndim</span>
|
||
<span class="k">if</span> <span class="o">-</span><span class="n">ndim</span> <span class="o"><=</span> <span class="n">value</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">dims</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">+=</span> <span class="n">ndim</span>
|
||
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">dims</span><span class="p">))</span>
|
||
|
||
<span class="n">start_values</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dims</span> <span class="k">else</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)</span>
|
||
<span class="p">]</span>
|
||
<span class="n">stride_values</span> <span class="o">=</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dims</span> <span class="k">else</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_slice</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">start</span><span class="o">=</span><span class="n">start_values</span><span class="p">,</span>
|
||
<span class="n">shape</span><span class="o">=</span><span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(),</span>
|
||
<span class="n">stride</span><span class="o">=</span><span class="n">stride_values</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
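# Usage sketch (illustrative): for a static 2D tensor `x` of shape [3, 4],
# `flip(x, [0])` reverses the rows and `flip(x, [-1])` reverses the columns
# (negative axes are normalized in place). The input must not be dynamic.
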
<div class="viewcode-block" id="interpolate">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.interpolate">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">interpolate</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">size</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">scale_factor</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">'nearest'</span><span class="p">,</span>
|
||
<span class="n">align_corners</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">recompute_scale_factor</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">antialias</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="c1">##</span>
|
||
<span class="c1">## TODO: Document that function!</span>
|
||
<span class="c1">##</span>
|
||
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
|
||
|
||
<span class="n">input_ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
|
||
<span class="k">assert</span> <span class="mi">2</span> <span class="o"><</span> <span class="n">input_ndim</span> <span class="o"><</span> <span class="mi">6</span><span class="p">,</span> <span class="s2">"Only 3D, 4D and 5D input Tensors supported"</span>
|
||
<span class="k">assert</span> <span class="p">(</span><span class="n">size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="o">^</span> <span class="p">(</span>
|
||
<span class="n">scale_factor</span>
|
||
<span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">),</span> <span class="s2">"Only one of out_shape or scales should be defined"</span>
|
||
|
||
<span class="k">assert</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">(</span><span class="s1">'nearest'</span><span class="p">,</span> <span class="s1">'linear'</span><span class="p">,</span> <span class="s1">'bilinear'</span><span class="p">,</span> <span class="s1">'bicubic'</span><span class="p">,</span> <span class="s1">'trilinear'</span><span class="p">,</span>
|
||
<span class="s1">'nearest-exact'</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">'trilinear'</span> <span class="ow">and</span> <span class="n">input_ndim</span> <span class="o">!=</span> <span class="mi">5</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"trilinear only supports 5D tensor"</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">"bilinear"</span> <span class="ow">and</span> <span class="n">input_ndim</span> <span class="o">!=</span> <span class="mi">4</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"bilinear only supports 4D tensor"</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">"linear"</span> <span class="ow">and</span> <span class="n">input_ndim</span> <span class="o">!=</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"linear only supports 3D tensor"</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_resize</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="n">input_shape</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
|
||
|
||
<span class="n">updated_shape</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">if</span> <span class="n">scale_factor</span><span class="p">:</span>
|
||
<span class="n">scale_len</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scale_factor</span><span class="p">,</span>
|
||
<span class="p">(</span><span class="nb">float</span><span class="p">,</span> <span class="nb">int</span><span class="p">))</span> <span class="k">else</span> <span class="nb">len</span><span class="p">(</span><span class="n">scale_factor</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">scale_len</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scale_factor</span><span class="p">,</span> <span class="p">(</span><span class="nb">float</span><span class="p">,</span> <span class="nb">int</span><span class="p">)):</span>
|
||
<span class="n">updated_scale</span> <span class="o">=</span> <span class="p">[</span><span class="n">scale_factor</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span> <span class="o">-</span> <span class="mi">2</span><span class="p">)]</span>
|
||
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">updated_scale</span> <span class="o">=</span> <span class="n">scale_factor</span>
|
||
<span class="n">updated_shape</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="nb">int</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">updated_scale</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">2</span><span class="p">]</span> <span class="o">*</span>
|
||
<span class="n">input_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span> <span class="k">if</span> <span class="n">i</span> <span class="o">></span> <span class="mi">1</span> <span class="k">else</span> <span class="n">input_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)</span>
|
||
<span class="p">]</span>
|
||
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">size_len</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> <span class="nb">len</span><span class="p">(</span><span class="n">size</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">size_len</span> <span class="o">==</span> <span class="n">input_ndim</span> <span class="o">-</span> <span class="mi">2</span>
|
||
<span class="k">if</span> <span class="n">size_len</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">updated_size</span> <span class="o">=</span> <span class="p">[</span><span class="n">size</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span> <span class="o">-</span> <span class="mi">2</span><span class="p">)]</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">updated_size</span> <span class="o">=</span> <span class="n">size</span>
|
||
|
||
<span class="n">updated_shape</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="n">input_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">if</span> <span class="n">i</span> <span class="o"><</span> <span class="mi">2</span> <span class="k">else</span> <span class="n">updated_size</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">2</span><span class="p">]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)</span>
|
||
<span class="p">]</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">updated_shape</span>
|
||
|
||
<span class="k">if</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'nearest'</span><span class="p">,</span> <span class="s1">'nearest-exact'</span><span class="p">]</span> <span class="ow">or</span> <span class="n">mode</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">InterpolationMode</span><span class="o">.</span><span class="n">NEAREST</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">ASYMMETRIC</span>
|
||
<span class="k">elif</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'linear'</span><span class="p">,</span> <span class="s1">'bilinear'</span><span class="p">,</span> <span class="s1">'trilinear'</span><span class="p">]:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">InterpolationMode</span><span class="o">.</span><span class="n">LINEAR</span>
|
||
<span class="k">if</span> <span class="n">align_corners</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">ALIGN_CORNERS</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">HALF_PIXEL</span>
|
||
<span class="c1"># TODO, need to confirm the align_corners effect on bilinear mode.</span>
|
||
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">'bilinear'</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">HALF_PIXEL</span>
|
||
|
||
<span class="k">elif</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'bicubic'</span><span class="p">]:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">InterpolationMode</span><span class="o">.</span><span class="n">CUBIC</span>
|
||
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">HALF_PIXEL</span>
|
||
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">InterpolationMode</span><span class="o">.</span><span class="n">NEAREST</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">ASYMMETRIC</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
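# Usage sketch (illustrative; mirrors torch.nn.functional.interpolate): exactly
# one of `size` or `scale_factor` must be given, and scalars broadcast over the
# spatial dimensions (all but the first two).
#
#     y = interpolate(x, scale_factor=2.0, mode='nearest')   # NCHW -> 2x H, W
#     y = interpolate(x, size=(64, 64), mode='bilinear',
#                     align_corners=False)                   # 4D input only
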
<div class="viewcode-block" id="matmul">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.matmul">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">matmul</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">mat2</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">transa</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">transb</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">use_fp32_acc</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add a matrix multiplication.</span>
|
||
|
||
<span class="sd"> That operation maps to a tensorrt.IMatrixMultiplyLayer layer. As explained</span>
|
||
<span class="sd"> in the TensorRT documentation, it computes the inner product between the</span>
|
||
<span class="sd"> two inputs after applying an optional transposition on the inputs.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The first tensor (often called A).</span>
|
||
|
||
<span class="sd"> mat2 : Tensor</span>
|
||
<span class="sd"> The second tensor (often called B).</span>
|
||
|
||
<span class="sd"> transa : bool</span>
|
||
<span class="sd"> Is the first input transposed? Set to 'True' if you want the first</span>
|
||
<span class="sd"> input to be transposed, 'False' otherwise.</span>
|
||
|
||
<span class="sd"> transb : bool</span>
|
||
<span class="sd"> Is the second input transposed? Set to 'True' if you want the</span>
|
||
<span class="sd"> second input to be transposed, 'False' otherwise.</span>
|
||
|
||
<span class="sd"> use_fp32_acc: bool</span>
|
||
<span class="sd"> Set to 'True' if for accuracy reason, this fp16 matmul needs to use</span>
|
||
<span class="sd"> fp32 accumulation. This can be a per model and per matmul decision.</span>
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the inserted layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="c1"># This option is only supported for fp16, but not bf16 or any other precisions.</span>
|
||
<span class="n">use_fp32_acc</span> <span class="o">=</span> <span class="n">use_fp32_acc</span> <span class="ow">and</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">HALF</span> <span class="ow">and</span> <span class="n">mat2</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">HALF</span>
|
||
|
||
<span class="k">if</span> <span class="n">use_fp32_acc</span><span class="p">:</span>
|
||
<span class="nb">input</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="s1">'float32'</span><span class="p">)</span>
|
||
<span class="n">mat2</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">mat2</span><span class="p">,</span> <span class="s1">'float32'</span><span class="p">)</span>
|
||
|
||
<span class="nb">input</span><span class="p">,</span> <span class="n">mat2</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">mat2</span><span class="p">)</span>
|
||
<span class="n">op0</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">TRANSPOSE</span> <span class="k">if</span> <span class="n">transa</span> \
|
||
<span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">NONE</span>
|
||
<span class="n">op1</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">TRANSPOSE</span> <span class="k">if</span> <span class="n">transb</span> \
|
||
<span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">NONE</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_matrix_multiply</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">op0</span><span class="p">,</span>
|
||
<span class="n">mat2</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">op1</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">use_fp32_acc</span><span class="p">:</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="s2">"float16"</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">output</span></div>


<div class="viewcode-block" id="gemm_swiglu">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gemm_swiglu">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">gemm_swiglu</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
                <span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
                <span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
                <span class="n">scale_d0</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
                <span class="n">scale_d1</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
                <span class="n">scale_output</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add a matrix multiplication followed by a SwiGLU (`x * SiLU(gate)`) operation.</span>

<span class="sd">    The subsequent SwiGLU operation takes the GEMM output, splits it into two</span>
<span class="sd">    halves along the last dimension, applies SiLU to the second half, and</span>
<span class="sd">    multiplies the results. The behaviour is undefined if the last dimension is</span>
<span class="sd">    not even.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The first tensor (often called A).</span>

<span class="sd">        weight : Tensor</span>
<span class="sd">            The second tensor (often called B).</span>

<span class="sd">        bias : Optional[Tensor]</span>
<span class="sd">            The per-channel bias. The plugin with fp8 dtype does not support bias yet.</span>

<span class="sd">        scale_d0 : float</span>
<span class="sd">            The scale for dequantizing x, used for fp8.</span>

<span class="sd">        scale_d1 : float</span>
<span class="sd">            The scale for dequantizing gate, used for fp8.</span>

<span class="sd">        scale_output : float</span>
<span class="sd">            The scale for quantizing the output, used for fp8.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor produced by the inserted layer.</span>
<span class="sd">    '''</span>
    <span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
        <span class="s1">'GemmSwiglu'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
    <span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>

    <span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_swiglu_plugin</span>
    <span class="k">if</span> <span class="n">p_dtype</span> <span class="o">==</span> <span class="s2">"fp8"</span><span class="p">:</span>
        <span class="k">assert</span> <span class="n">bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"fp8 gemm_swiglu does not support bias yet"</span>

    <span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
        <span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
        <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">pf_has_bias</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
        <span class="s2">"has_bias"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="mi">0</span> <span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="mi">1</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
        <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
    <span class="n">pf_scale_d0</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"scale_d0"</span><span class="p">,</span>
                                  <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">scale_d0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
                                  <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
    <span class="n">pf_scale_d1</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"scale_d1"</span><span class="p">,</span>
                                  <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">scale_d1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
                                  <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
    <span class="n">pf_scale_output</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"scale_output"</span><span class="p">,</span>
                                      <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">scale_output</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
                                      <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>

    <span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">(</span>
        <span class="p">[</span><span class="n">pf_type</span><span class="p">,</span> <span class="n">pf_has_bias</span><span class="p">,</span> <span class="n">pf_scale_d0</span><span class="p">,</span> <span class="n">pf_scale_d1</span><span class="p">,</span> <span class="n">pf_scale_output</span><span class="p">])</span>
    <span class="n">gemm_swiglu_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"gemm_swiglu"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>

    <span class="c1"># TODO(anchengc) pass nullptr when no bias</span>
    <span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">bias</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span>
            <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">weight</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">dtype</span><span class="p">)))</span>
    <span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>

    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">gemm_swiglu_plug</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>


<div class="viewcode-block" id="constant">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.constant">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">constant</span><span class="p">(</span><span class="n">ndarray</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
             <span class="n">as_dtype</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span> <span class="o">|</span> <span class="kc">None</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
             <span class="n">as_shape</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add a constant layer.</span>

<span class="sd">    TensorRT graphs encapsulate constant values in the form of constant layers</span>
<span class="sd">    (tensorrt.IConstantLayer). This function creates such a layer from a Numpy</span>
<span class="sd">    array of values. After compilation of the network by TensorRT, those</span>
<span class="sd">    weights are stored in the serialized TensorRT engine.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        ndarray : numpy.ndarray</span>
<span class="sd">            The array of values (weights) encapsulated by this constant layer.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor produced by the inserted layer.</span>
<span class="sd">    '''</span>
    <span class="n">trt_dtype</span> <span class="o">=</span> <span class="n">np_dtype_to_trt</span><span class="p">(</span><span class="n">ndarray</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> <span class="k">if</span> <span class="n">as_dtype</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">as_dtype</span>
    <span class="n">trt_shape</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Dims</span><span class="p">(</span>
        <span class="n">ndarray</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="k">if</span> <span class="n">as_shape</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Dims</span><span class="p">(</span><span class="n">as_shape</span><span class="p">)</span>
    <span class="n">trt_count</span> <span class="o">=</span> <span class="mi">1</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">trt_shape</span><span class="p">)):</span>
        <span class="n">trt_count</span> <span class="o">*=</span> <span class="n">trt_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
    <span class="n">weights</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">(</span><span class="n">trt_dtype</span><span class="p">,</span> <span class="n">ndarray</span><span class="o">.</span><span class="n">ctypes</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">trt_count</span><span class="p">)</span>
    <span class="c1"># Prevent underlying numpy array from going out of scope</span>
    <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">register_ndarray</span><span class="p">(</span><span class="n">ndarray</span><span class="p">)</span>
    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_constant</span><span class="p">(</span><span class="n">trt_shape</span><span class="p">,</span> <span class="n">weights</span><span class="p">)</span>
    <span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
        <span class="n">layer</span><span class="o">.</span><span class="n">set_output_type</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">trt_dtype</span><span class="p">)</span>
    <span class="n">tensor</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
    <span class="c1"># TODO: remove this WAR after https://nvbugs/4359151 fixed.</span>
    <span class="n">set_np_weight</span><span class="p">(</span><span class="n">default_trtnet</span><span class="p">(),</span> <span class="n">layer</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">ndarray</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">tensor</span></div>


<span class="c1"># TODO: TensorRT uses sizes of the output dimensions.</span>
<span class="c1"># DL frameworks usually use ends. Will change it to ends.</span>
<div class="viewcode-block" id="slice">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.slice">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">slice</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
          <span class="n">starts</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
          <span class="n">sizes</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
          <span class="n">strides</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
          <span class="n">mode</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">SampleMode</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
          <span class="n">fill_value</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add an operation to extract a slice from a tensor.</span>

<span class="sd">    As described in the TensorRT documentation of the ISliceLayer, the slice</span>
<span class="sd">    layer has two variants: static and dynamic.</span>

<span class="sd">    For static slicing, this function takes the starts and sizes values in the</span>
<span class="sd">    different dimensions to slice at layer creation time via a sequence of</span>
<span class="sd">    integers. For dynamic slicing, it accepts starts and sizes as</span>
<span class="sd">    tensorrt.ITensor objects.</span>

<span class="sd">    The slice layer selects for each dimension a start location from within</span>
<span class="sd">    the input tensor, and copies elements to the output tensor using the given</span>
<span class="sd">    strides (1 by default) across the input tensor. Start and size tensors</span>
<span class="sd">    must be 1-D int32 shape tensors if not specified as a sequence of</span>
<span class="sd">    integers.</span>

<span class="sd">    As an example, on input = [[0, 2, 4], [1, 3, 5]], the call to</span>

<span class="sd">        slice(input, starts=[1, 0], sizes=[1, 2])</span>

<span class="sd">    will produce the tensor [[1, 3]] as output. The slice operator, when</span>
<span class="sd">    executed by TensorRT, will copy one row (because sizes[0] == 1) starting</span>
<span class="sd">    from the 2nd row (because starts[0] == 1) and two columns (sizes[1] == 2)</span>
<span class="sd">    starting from the 1st column (because starts[1] == 0).</span>

<span class="sd">    In pseudo-code, the behavior of that operation can be described as follows</span>
<span class="sd">    for a 2D tensor (and easily be extended to more dimensions):</span>

<span class="sd">        output = Tensor(shape=sizes)</span>
<span class="sd">        for ii in range(sizes[0]):</span>
<span class="sd">            for jj in range(sizes[1]):</span>
<span class="sd">                output[ii][jj] = input[starts[0]+ii][starts[1]+jj]</span>

<span class="sd">    Note that it is common in deep-learning frameworks to use ranges</span>
<span class="sd">    [start:end] for similar operations. They can be emulated by setting the</span>
<span class="sd">    sizes argument such that in each dimension [start:start+size] ==</span>
<span class="sd">    [start:end], i.e. size = end-start.</span>

<span class="sd">    TensorRT supports different slice modes. By default the layer uses</span>
<span class="sd">    `tensorrt.SampleMode.STRICT_BOUNDS`; another mode can be selected through</span>
<span class="sd">    the `mode` parameter.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The input tensor on which the slicing is performed.</span>

<span class="sd">        starts : Union[Tensor, Sequence[int]]</span>
<span class="sd">            The starting points, in the input tensor, in each dimension.</span>

<span class="sd">        sizes : Union[Tensor, Sequence[int]]</span>
<span class="sd">            The number of elements in each dimension of the sliced tensor (output).</span>

<span class="sd">        strides : Union[Tensor, Sequence[int]]</span>
<span class="sd">            The step to be taken from start, in the input tensor, in each dimension.</span>

<span class="sd">        mode : trt.SampleMode</span>
<span class="sd">            The mode that controls how the slice operation handles out-of-bounds coordinates.</span>

<span class="sd">        fill_value : Union[float, Tensor]</span>
<span class="sd">            The value used when `mode == trt.SampleMode.FILL`.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor produced by the slice layer.</span>
<span class="sd">    '''</span>
    <span class="n">input_ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>

    <span class="n">trt_starts</span> <span class="o">=</span> <span class="n">starts</span>
    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">starts</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
        <span class="n">trt_starts</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)]</span>  <span class="c1"># unused dummy value</span>

    <span class="n">trt_sizes</span> <span class="o">=</span> <span class="n">sizes</span>
    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
        <span class="n">trt_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)]</span>  <span class="c1"># unused dummy value</span>

    <span class="n">trt_strides</span> <span class="o">=</span> <span class="n">strides</span>
    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">strides</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="ow">or</span> <span class="n">strides</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">trt_strides</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)]</span>

    <span class="k">if</span> <span class="n">fill_value</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">fill_value</span><span class="p">,</span> <span class="nb">float</span><span class="p">):</span>
        <span class="n">fill_value</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">(</span><span class="n">fill_value</span><span class="p">))</span>

    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_slice</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
                                      <span class="n">start</span><span class="o">=</span><span class="n">trt_starts</span><span class="p">,</span>
                                      <span class="n">shape</span><span class="o">=</span><span class="n">trt_sizes</span><span class="p">,</span>
                                      <span class="n">stride</span><span class="o">=</span><span class="n">trt_strides</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">mode</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">layer</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span>

    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">starts</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
        <span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">starts</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>

    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
        <span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">sizes</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>

    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">strides</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
        <span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">strides</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">mode</span> <span class="ow">is</span> <span class="n">trt</span><span class="o">.</span><span class="n">SampleMode</span><span class="o">.</span><span class="n">FILL</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">fill_value</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
        <span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="n">fill_value</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>


<div class="viewcode-block" id="pad">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.pad">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">pad</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
        <span class="n">pad</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">Tensor</span><span class="p">],</span>
        <span class="n">mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">'constant'</span><span class="p">,</span>
        <span class="n">value</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add a pad layer.</span>

<span class="sd">    The padding layer adds padding at the start and end of the input tensor.</span>
<span class="sd">    The padding sizes by which to pad some dimensions of input are described</span>
<span class="sd">    starting from the last dimension and moving forward.</span>

<span class="sd">    `len(pad) / 2` dimensions of input will be padded. For example, to pad only</span>
<span class="sd">    the last dimension of the input tensor, pad has the form [padding_left,</span>
<span class="sd">    padding_right]; to pad the last 2 dimensions of the input tensor, use</span>
<span class="sd">    [padding_left, padding_right, padding_top, padding_bottom]; to pad the</span>
<span class="sd">    last 3 dimensions, use [padding_left, padding_right, padding_top,</span>
<span class="sd">    padding_bottom, padding_front, padding_back].</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The input tensor on which the padding is performed.</span>
<span class="sd">        pad : Union[Sequence[int], Tensor]</span>
<span class="sd">            An m-element sequence of padding sizes, where m is even and</span>
<span class="sd">            m <= 2 * input.ndim().</span>
<span class="sd">        mode : str</span>
<span class="sd">            Only 'constant' is supported.</span>
<span class="sd">        value : float</span>
<span class="sd">            Fill value for 'constant' padding. Default: 0.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor produced by the inserted layer.</span>
<span class="sd">    '''</span>
    <span class="k">assert</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">"constant"</span><span class="p">,</span> <span class="s2">"Only `'constant'` is supported now."</span>

    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pad</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pad</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
        <span class="k">assert</span> <span class="p">(</span>
            <span class="nb">len</span><span class="p">(</span><span class="n">pad</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">pad</span><span class="p">)</span> <span class="o"><=</span> <span class="mi">2</span> <span class="o">*</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
        <span class="p">),</span> <span class="s2">"The length of `pad` should be even and less than 2*input.ndim"</span>
        <span class="n">pad</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">pad</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">))</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
    <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pad</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
        <span class="n">pad</span> <span class="o">=</span> <span class="n">pad</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
        <span class="k">assert</span> <span class="p">(</span>
            <span class="n">pad</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">pad</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o"><=</span> <span class="mi">2</span> <span class="o">*</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
        <span class="p">),</span> <span class="s2">"The length of `pad` should be even and less than 2*input.ndim"</span>
        <span class="n">pad</span> <span class="o">=</span> <span class="n">pad</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="s2">"int32"</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"pad type </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">pad</span><span class="p">)</span><span class="si">}</span><span class="s2"> not supported"</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">value</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">value</span> <span class="o">=</span> <span class="mi">0</span>

    <span class="n">pad</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">constant</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)),</span>
                  <span class="n">pad</span><span class="p">])</span>  <span class="c1"># pre-padding the indices</span>
    <span class="n">padding_index</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
    <span class="n">padding_index</span><span class="p">[</span><span class="o">-</span><span class="p">(</span><span class="n">pad</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):]</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">pad</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span>
                                                   <span class="o">-</span><span class="mi">1</span><span class="p">))</span>  <span class="c1"># reverse the indices</span>
    <span class="n">pad</span> <span class="o">=</span> <span class="n">index_select</span><span class="p">(</span><span class="n">pad</span><span class="p">,</span>
                      <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
                      <span class="n">index</span><span class="o">=</span><span class="n">constant</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">padding_index</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)))</span>
    <span class="n">pre_padding</span><span class="p">,</span> <span class="n">post_padding</span> <span class="o">=</span> <span class="n">chunk</span><span class="p">(</span><span class="n">pad</span><span class="p">,</span> <span class="n">chunks</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">start</span> <span class="o">=</span> <span class="p">(</span><span class="n">pre_padding</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span> <span class="o">*</span> <span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="s1">'int32'</span><span class="p">)</span>
    <span class="n">extend_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">pre_padding</span> <span class="o">+</span> <span class="n">post_padding</span><span class="p">)</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
    <span class="n">size</span> <span class="o">=</span> <span class="p">(</span><span class="n">extend_size</span> <span class="o">+</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">))</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="s1">'int32'</span><span class="p">)</span>
    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_slice</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
                                      <span class="n">start</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">(),</span>
                                      <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">(),</span>
                                      <span class="n">stride</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
    <span class="n">layer</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">SampleMode</span><span class="o">.</span><span class="n">FILL</span>
    <span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="n">SliceInputType</span><span class="o">.</span><span class="n">start</span><span class="p">,</span> <span class="n">start</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
    <span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="n">SliceInputType</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">size</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
    <span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="n">SliceInputType</span><span class="o">.</span><span class="n">fill_value</span><span class="p">,</span>
                    <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">input</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
    <span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">output</span></div>


<div class="viewcode-block" id="rand">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.rand">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">rand</span><span class="p">(</span><span class="n">shape</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
         <span class="n">low</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
         <span class="n">high</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
         <span class="n">dtype</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'float32'</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    This operation adds a fill layer that generates a random (uniform) tensor</span>
<span class="sd">    with the specified shape and data type.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        shape: Tensor</span>
<span class="sd">            The shape of the tensor to be generated.</span>
<span class="sd">        low: float</span>
<span class="sd">            The minimum value (inclusive) of the sampling range.</span>
<span class="sd">        high: float</span>
<span class="sd">            The maximum value (inclusive) of the sampling range.</span>
<span class="sd">        dtype: Union[str, trt.DataType]</span>
<span class="sd">            The desired data type for the output tensor.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The generated random tensor produced by the fill layer.</span>
<span class="sd">    '''</span>
    <span class="c1"># NOTE: DISABLED FOR NOW UNTIL THE FILL LAYER (RANDOM_UNIFORM) in TRT IS FIXED</span>
    <span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s2">"The rand() op is temporarily disabled."</span>
    <span class="n">low</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">(</span><span class="n">low</span><span class="p">))</span>
    <span class="n">high</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">(</span><span class="n">high</span><span class="p">))</span>
    <span class="n">trt_dtype</span> <span class="o">=</span> <span class="n">dtype</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span>
                                     <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">)</span> <span class="k">else</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>

    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_fill</span><span class="p">([</span><span class="mi">0</span><span class="p">],</span> <span class="n">trt</span><span class="o">.</span><span class="n">FillOperation</span><span class="o">.</span><span class="n">RANDOM_UNIFORM</span><span class="p">,</span>
                                      <span class="n">trt_dtype</span><span class="p">)</span>

    <span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">shape</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
    <span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">low</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
    <span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">high</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>


<div class="viewcode-block" id="categorical_sample">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.categorical_sample">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">categorical_sample</span><span class="p">(</span><span class="n">probs</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">rand_data</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    This is a sampling operation, an equivalent of torch.distributions.Categorical.sample(),</span>
<span class="sd">    i.e. given a probability distribution tensor, it samples an index of that tensor.</span>
<span class="sd">    See: https://pytorch.org/docs/stable/distributions.html#torch.distributions.categorical.Categorical.sample</span>
<span class="sd">    NOTE: The given probabilities do **not** need to be normalized; they are normalized internally.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        probs: Tensor</span>
<span class="sd">            A floating point tensor whose last dimension holds the (unnormalized) probability distribution.</span>
<span class="sd">        rand_data: Tensor (optional)</span>
<span class="sd">            A random tensor with the same shape as `probs` minus its last dimension.</span>
<span class="sd">            If not provided, this function will add a rand() op to generate it and use it for sampling.</span>

<span class="sd">    Returns:</span>
<span class="sd">        A tensor containing, for each distribution, the index of `probs` representing the sample.</span>
<span class="sd">    '''</span>
    <span class="n">probs</span> <span class="o">=</span> <span class="n">probs</span> <span class="o">/</span> <span class="nb">sum</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    <span class="n">rand_shape</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">assert</span> <span class="n">probs</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">></span> <span class="mi">0</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">probs</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
        <span class="n">rand_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
    <span class="n">rand_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">rand_shape</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">rand_data</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">rand_data</span> <span class="o">=</span> <span class="n">rand</span><span class="p">(</span><span class="n">rand_shape</span><span class="p">,</span> <span class="n">low</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">probs</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
    <span class="k">assert</span> <span class="n">rand_shape</span> <span class="o">==</span> <span class="n">shape</span><span class="p">(</span><span class="n">rand_data</span><span class="p">)</span>
    <span class="n">rand_data</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">rand_data</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">shape</span><span class="p">(</span><span class="n">probs</span><span class="p">))</span>
    <span class="n">cum_probs</span> <span class="o">=</span> <span class="n">cumsum</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">cmp</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">cum_probs</span> <span class="o">>=</span> <span class="n">rand_data</span><span class="p">,</span> <span class="n">probs</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
    <span class="n">samples</span> <span class="o">=</span> <span class="n">argmax</span><span class="p">(</span><span class="n">cmp</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">samples</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="Conditional">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Conditional">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">Conditional</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to conditionally execute two code paths/subgraphs.</span>
|
||
|
||
<span class="sd"> Usage:</span>
|
||
<span class="sd"> 1. conditional = Conditional(condition)</span>
|
||
<span class="sd"> 2. input_1_ = conditional.add_input(input_1)</span>
|
||
<span class="sd"> ...</span>
|
||
<span class="sd"> input_n_ = conditional.add_input(input_n)</span>
|
||
<span class="sd"> 3. Construct the graph to get true_output_value and false_output_value using input_1_, ..., input_n_</span>
|
||
<span class="sd"> 4. output = conditional.add_output(true_output_value, false_output_value)</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">condition</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_if_conditional</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">condition</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">condition</span> <span class="o">=</span> <span class="n">view</span><span class="p">(</span><span class="n">condition</span><span class="p">,</span> <span class="p">[])</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">layer</span><span class="o">.</span><span class="n">set_condition</span><span class="p">(</span><span class="n">condition</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<div class="viewcode-block" id="Conditional.add_input">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Conditional.add_input">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">add_input</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="n">in_node</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer</span><span class="o">.</span><span class="n">add_input</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">in_node</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">in_node</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="Conditional.add_output">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Conditional.add_output">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">add_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">true_value</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">false_value</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="n">out_node</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer</span><span class="o">.</span><span class="n">add_output</span><span class="p">(</span><span class="n">true_value</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">false_value</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">out_node</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">out_node</span><span class="p">)</span></div>
</div>
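
<span class="c1"># Usage sketch (added for illustration, not part of the original source). It</span>
<span class="c1"># follows the steps from the class docstring; `cond`, `x` and `y` are assumed</span>
<span class="c1"># to be tensors of a network under construction:</span>
<span class="c1">#</span>
<span class="c1">#   conditional = Conditional(cond)      # 1. wrap the boolean condition</span>
<span class="c1">#   x_ = conditional.add_input(x)        # 2. register every input</span>
<span class="c1">#   y_ = conditional.add_input(y)</span>
<span class="c1">#   t, f = x_ + y_, x_ - y_              # 3. build both branches from x_, y_</span>
<span class="c1">#   out = conditional.add_output(t, f)   # 4. select the branch output</span>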


<span class="c1"># TODO: support step.</span>
<div class="viewcode-block" id="arange">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.arange">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">arange</span><span class="p">(</span><span class="n">start</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span> <span class="n">end</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span>
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to fill a 1D tensor.</span>
|
||
|
||
<span class="sd"> The tensor is filled with the values between start and end with a step of 1</span>
|
||
<span class="sd"> between the different elements. In pseudo-code, it corresponds to a tensor</span>
|
||
<span class="sd"> populated with the values:</span>
|
||
|
||
<span class="sd"> output = Tensor([dtype(ii) for ii in range(start, end, 1)])</span>
|
||
|
||
<span class="sd"> For example, a call to arange(3, 6, 'int32') will add an operation to the</span>
|
||
<span class="sd"> TensorRT graph that will produce [3, 4, 5] when executed. The call to</span>
|
||
<span class="sd"> arange(2, 5, 'float32') will add a layer to generate [2.0, 3.0, 4.0].</span>
|
||
|
||
<span class="sd"> This operation is implemented using a tensorrt.IFillLayer in</span>
|
||
<span class="sd"> trt.FillOperation.LINSPACE mode.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> start : Union[Tensor, int]</span>
|
||
<span class="sd"> The starting point of the range.</span>
|
||
|
||
<span class="sd"> end : Union[Tensor, int]</span>
|
||
<span class="sd"> The end point of the range.</span>
|
||
|
||
<span class="sd"> dtype : str</span>
|
||
<span class="sd"> The type of the elements. See _str_to_trt_dtype_dict in _utils.py</span>
|
||
<span class="sd"> for a list of supported types and type names.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the fill layer. It is a 1D tensor containing</span>
|
||
<span class="sd"> `end-start` elements of type `dtype`.</span>
|
||
<span class="sd"> '''</span>
<span class="n">res_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">end</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span>
<span class="n">array_func</span> <span class="o">=</span> <span class="n">int32_array</span> <span class="k">if</span> <span class="n">res_dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span> <span class="k">else</span> <span class="n">int64_array</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">array_func</span><span class="p">(</span><span class="n">start</span><span class="p">))</span>
<span class="n">end</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">array_func</span><span class="p">(</span><span class="n">end</span><span class="p">))</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">end</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">start</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span> <span class="ow">or</span> <span class="n">start</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span>
<span class="k">assert</span> <span class="n">end</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span> <span class="ow">or</span> <span class="n">end</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span>
<span class="k">if</span> <span class="n">start</span><span class="o">.</span><span class="n">dtype</span> <span class="o">!=</span> <span class="n">end</span><span class="o">.</span><span class="n">dtype</span><span class="p">:</span>
<span class="k">if</span> <span class="n">start</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span> <span class="c1"># end == trt.int64</span>
<span class="k">if</span> <span class="n">res_dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span>
<span class="n">end</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">end</span><span class="p">,</span> <span class="s2">"int32"</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="s2">"int64"</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span> <span class="c1"># start == trt.int64 and end == trt.int32</span>
<span class="k">if</span> <span class="n">res_dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="s2">"int32"</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">end</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">end</span><span class="p">,</span> <span class="s2">"int64"</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">"</span><span class="si">%s</span><span class="s2"> is not supported"</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">start</span><span class="p">))</span>
<span class="k">assert</span> <span class="n">start</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">end</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"start type (</span><span class="si">{</span><span class="n">start</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) != end type (</span><span class="si">{</span><span class="n">end</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">)"</span>
<span class="n">step</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">start</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">to_array</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">num</span> <span class="o">=</span> <span class="n">end</span> <span class="o">-</span> <span class="n">start</span>
<span class="n">num</span> <span class="o">=</span> <span class="n">num</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_fill</span><span class="p">([</span><span class="mi">0</span><span class="p">],</span> <span class="n">trt</span><span class="o">.</span><span class="n">FillOperation</span><span class="o">.</span><span class="n">LINSPACE</span><span class="p">,</span>
<span class="n">start</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">num</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span> <span class="c1"># rank = 1</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">start</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span> <span class="c1"># rank = 0</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">step</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span> <span class="c1"># rank = 1</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">if</span> <span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span> <span class="o">!=</span> <span class="n">res_dtype</span><span class="p">:</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">return</span> <span class="n">tensor</span></div>


<div class="viewcode-block" id="expand">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">expand</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">expand_shape</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to expand a tensor.</span>
|
||
|
||
<span class="sd"> The operation expands the input tensor in the singleton dimensions to the</span>
|
||
<span class="sd"> size indicated by the corresponding dimension in the `expand_shape` tensor.</span>
|
||
<span class="sd"> In other words, given an input tensor with dimensions of size 1, those</span>
|
||
<span class="sd"> dimensions will be expanded to the size in `expand_shape`.</span>
|
||
|
||
<span class="sd"> For example, a tensor of shape [4, 3, 1, 3] will be expanded to a tensor of</span>
|
||
<span class="sd"> shape [4, 3, 2, 3] by the layer created using expand(input, [4, 3, 2, 3]).</span>
|
||
|
||
<span class="sd"> The expansion may either replicate the values or be mapped to a view with a</span>
|
||
<span class="sd"> stride of 0 in the expanded dimensions. For example, for a tensor [[3, 2]] of</span>
|
||
<span class="sd"> shape [1, 2],</span>
|
||
|
||
<span class="sd"> expand([[3, 2]], [2, 2])</span>
|
||
|
||
<span class="sd"> can be used to expand the input to [[3, 2], [3, 2]].</span>
|
||
|
||
<span class="sd"> This operation is implemented using a tensorrt.ISliceLayer. The current</span>
|
||
<span class="sd"> implementation does not verify that non singleton dimensions are not</span>
|
||
<span class="sd"> shrunk. In other words, for an input of shape [4, 1, 2],</span>
|
||
|
||
<span class="sd"> expand(input, [3, 2, 2])</span>
|
||
|
||
<span class="sd"> will produce a tensor of shape [3, 2, 2]. That behavior is subject to</span>
|
||
<span class="sd"> change in the future.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> expand_shape : Tensor</span>
|
||
<span class="sd"> The new shape of the expanded tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the expand layer.</span>
|
||
<span class="sd"> '''</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_slice</span><span class="p">(</span>
<span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">start</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)],</span>
<span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)],</span> <span class="c1"># unused dummy value</span>
<span class="n">stride</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span> <span class="c1"># unused dummy value</span>
<span class="p">)</span>
<span class="c1"># The stride is either:</span>
<span class="c1"># 0 for dimensions of size 1 (i.e. shape(input, i) - 1 == 1 - 1 == 0) or,</span>
<span class="c1"># 1 for dimensions of size > 1 since minimum(value >= 1, 1) == 1.</span>
<span class="n">stride_tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span>
<span class="p">[</span><span class="n">minimum</span><span class="p">((</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)])</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">expand_shape</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride_tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>


<div class="viewcode-block" id="einsum">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.einsum">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">einsum</span><span class="p">(</span><span class="n">einsum_eq</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">inputs</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tensor</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an Einsum operation.</span>
|
||
|
||
<span class="sd"> That operation maps to tensorrt.IEinsumLayer. As explained in the TensorRT</span>
|
||
<span class="sd"> documentation, this layer implements a summation over the elements of the</span>
|
||
<span class="sd"> inputs along dimensions specified by the equation parameter, based on the</span>
|
||
<span class="sd"> Einstein summation convention. The layer can have one or more inputs of</span>
|
||
<span class="sd"> rank >= 0. All the inputs must be of same data type. This layer supports</span>
|
||
<span class="sd"> all TensorRT data types except bool. There is one output tensor of the same</span>
|
||
<span class="sd"> type as the input tensors. The shape of output tensor is determined by the</span>
|
||
<span class="sd"> equation.</span>
|
||
|
||
<span class="sd"> The equation specifies ASCII lower-case letters for each dimension in the</span>
|
||
<span class="sd"> inputs in the same order as the dimensions, separated by comma for each</span>
|
||
<span class="sd"> input. The dimensions labeled with the same subscript must match or be</span>
|
||
<span class="sd"> able to be broadcasted. Repeated subscript labels in one input take the diagonal.</span>
|
||
<span class="sd"> Repeating a label across multiple inputs means that those axes will be</span>
|
||
<span class="sd"> multiplied. Omitting a label from the output means values along those axes</span>
|
||
<span class="sd"> will be summed. In implicit mode, the indices which appear once in the</span>
|
||
<span class="sd"> expression will be part of the output in increasing alphabetical order. In</span>
|
||
<span class="sd"> explicit mode, the output can be controlled by specifying output subscript</span>
|
||
<span class="sd"> labels by adding an arrow (‘->’) followed by subscripts for the output. For</span>
|
||
<span class="sd"> example, “ij,jk->ik” is equivalent to “ij,jk”. Ellipsis (‘…’) can be used</span>
|
||
<span class="sd"> in place of subscripts to broadcast the dimensions. See the TensorRT</span>
|
||
<span class="sd"> Developer Guide for more details on equation syntax.</span>
|
||
|
||
<span class="sd"> Many common operations can be expressed using the Einsum equation. For</span>
|
||
<span class="sd"> example:</span>
|
||
<span class="sd"> Matrix Transpose: ij->ji</span>
|
||
<span class="sd"> Sum: ij-> Matrix-Matrix</span>
|
||
<span class="sd"> Multiplication: ik,kj->ij</span>
|
||
<span class="sd"> Dot Product: i,i-></span>
|
||
<span class="sd"> Matrix-Vector Multiplication: ik,k->i</span>
|
||
<span class="sd"> Batch Matrix Multiplication: ijk,ikl->ijl</span>
|
||
<span class="sd"> Batch Diagonal: …ii->…i</span>
|
||
|
||
<span class="sd"> Note that TensorRT does not support ellipsis or diagonal operations so,</span>
|
||
<span class="sd"> neither, does TensorRT-LLM.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> einsum_eq : str</span>
|
||
<span class="sd"> The Einsum equation.</span>
|
||
|
||
<span class="sd"> inputs: Sequence[Tensor]</span>
|
||
<span class="sd"> The sequence of inputs consumed by the Einsum operation.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the Einsum operation.</span>
|
||
<span class="sd"> '''</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_einsum</span><span class="p">([</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">],</span>
<span class="n">einsum_eq</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>


<div class="viewcode-block" id="permute">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.permute">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">permute</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to permute the dimensions of a tensor.</span>
|
||
|
||
<span class="sd"> The dimensions of the input tensor are permuted according to the sequence</span>
|
||
<span class="sd"> of dimensions in 'dims'. That operation maps to tensorrt.IShuffleLayer where</span>
|
||
<span class="sd"> the second transposition is described by the indices in 'dims'.</span>
|
||
|
||
<span class="sd"> Given a tensor of rank N, the result of the permutation is a tensor of rank</span>
|
||
<span class="sd"> N in which the i-th input dimension maps to the dims[i]-th dimension.</span>
|
||
|
||
<span class="sd"> For example, permute(input, [1, 0]) will transpose a 2D tensor by permuting</span>
|
||
<span class="sd"> the rows and columns.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to permute.</span>
|
||
|
||
<span class="sd"> dims : Sequence[int]</span>
|
||
<span class="sd"> The description of the permutation.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the permutation layer.</span>
|
||
<span class="sd"> '''</span>
<span class="n">dims</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">dims</span><span class="p">),</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shuffle</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">second_transpose</span> <span class="o">=</span> <span class="n">dims</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>


<div class="viewcode-block" id="transpose">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.transpose">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">transpose</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim0</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim1</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to transpose two dimensions of a tensor.</span>
|
||
|
||
<span class="sd"> That operation produces a tensor in which the dimensions 'dim0' and 'dim1'</span>
|
||
<span class="sd"> are permuted. The other dimensions, if the rank of the tensor is greater</span>
|
||
<span class="sd"> than 2, remain untouched.</span>
|
||
|
||
<span class="sd"> That function is a helper built on the 'functional.permute' function.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to transpose.</span>
|
||
|
||
<span class="sd"> dim0 : int</span>
|
||
<span class="sd"> The first dimension to transpose.</span>
|
||
|
||
<span class="sd"> dim1 : int</span>
|
||
<span class="sd"> The second dimension to transpose.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the permutation layer.</span>
|
||
<span class="sd"> '''</span>
<span class="n">permutation</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()))</span>
<span class="n">permutation</span><span class="p">[</span><span class="n">dim0</span><span class="p">]</span> <span class="o">=</span> <span class="n">dim1</span>
<span class="n">permutation</span><span class="p">[</span><span class="n">dim1</span><span class="p">]</span> <span class="o">=</span> <span class="n">dim0</span>
<span class="k">return</span> <span class="n">permute</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">permutation</span><span class="p">)</span></div>


<div class="viewcode-block" id="view">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.view">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">view</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">zero_is_placeholder</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to create a view of a tensor.</span>
|
||
|
||
<span class="sd"> That operation adds a tensorrt.IShuffleLayer to the network. If the 'shape'</span>
|
||
<span class="sd"> parameter is a Tensor, that view is dynamic. Otherwise, it is a static</span>
|
||
<span class="sd"> view.</span>
|
||
|
||
<span class="sd"> Note that TensorRT limits the number of inferred dimensions to 1. It means</span>
|
||
<span class="sd"> that the shape sequence or tensor cannot contain more than one -1. This</span>
|
||
<span class="sd"> function enforces that constraint and will assert if it is not respected.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to transpose.</span>
|
||
|
||
<span class="sd"> shape : Union[Tensor, Sequence[int]]</span>
|
||
<span class="sd"> The shape of the new tensor.</span>
|
||
|
||
<span class="sd"> zero_is_placeholder : bool</span>
|
||
<span class="sd"> When that parameter is True, the 0s in 'shape' are replaced by the</span>
|
||
<span class="sd"> sizes of the corresponding dimensions from the 'input'. Otherwise,</span>
|
||
<span class="sd"> the dimensions corresponding to 0s are shrunk.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the view/shuffle layer.</span>
|
||
<span class="sd"> '''</span>
<span class="c1"># TensorRT demands that at most one dimension is permitted to be specified as -1</span>
<span class="k">def</span><span class="w"> </span><span class="nf">assert_no_more_than_one_inferred_dim</span><span class="p">(</span><span class="nb">list</span><span class="p">):</span>
<span class="n">inferred_dim_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">list</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">inferred_dim_list</span><span class="p">)</span> <span class="o"><=</span> <span class="mi">1</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shuffle</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">zero_is_placeholder</span> <span class="o">=</span> <span class="n">zero_is_placeholder</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="n">assert_no_more_than_one_inferred_dim</span><span class="p">(</span><span class="n">shape</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">shape</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
<span class="n">assert_no_more_than_one_inferred_dim</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">reshape_dims</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">"</span><span class="si">%s</span><span class="s2"> is not supported"</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">shape</span><span class="p">))</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>


<div class="viewcode-block" id="flatten">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.flatten">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">flatten</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">start_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">end_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">):</span>
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Flattens input by reshaping it into a one-dimensional tensor.</span>
|
||
|
||
<span class="sd"> If start_dim or end_dim are passed, only dimensions starting with start_dim and</span>
|
||
<span class="sd"> ending with end_dim are flattened. The order of elements in input is unchanged.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to flatten.</span>
|
||
|
||
<span class="sd"> start_dim : int</span>
|
||
<span class="sd"> The first dim to flatten.</span>
|
||
|
||
<span class="sd"> end_dim : int</span>
|
||
<span class="sd"> The last dim to flatten.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the flatten layer.</span>
|
||
|
||
<span class="sd"> '''</span>
<span class="n">shape</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">shape</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">if</span> <span class="n">start_dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span> <span class="n">start_dim</span> <span class="o">+=</span> <span class="n">ndim</span>
<span class="k">if</span> <span class="n">end_dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span> <span class="n">end_dim</span> <span class="o">+=</span> <span class="n">ndim</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">start_dim</span><span class="p">):</span>
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
<span class="k">if</span> <span class="n">end_dim</span> <span class="o">-</span> <span class="n">start_dim</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">flat_dim</span> <span class="o">=</span> <span class="mi">1</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">start_dim</span><span class="p">,</span> <span class="n">end_dim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
<span class="n">flat_dim</span> <span class="o">*=</span> <span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">flat_dim</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">end_dim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">ndim</span><span class="p">):</span>
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
<span class="k">return</span> <span class="n">view</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">new_shape</span><span class="p">)</span></div>


<div class="viewcode-block" id="expand_dims">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand_dims">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">shape_cast_dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to expand the tensor shape with singleton dimensions.</span>
|
||
|
||
<span class="sd"> That function adds a tensorrt.IShuffleLayer to the network. Given an 'input'</span>
|
||
<span class="sd"> of rank N and a sequence of M dimensions, the output tensor produced by</span>
|
||
<span class="sd"> this operation (when executed by TensorRT) will have a rank of N+M. Singleton</span>
|
||
<span class="sd"> dimensions will be inserted at the different positions in 'dim'.</span>
|
||
|
||
<span class="sd"> The pseudo-code for that operation is:</span>
|
||
|
||
<span class="sd"> new_shape, ii = [], 0</span>
|
||
<span class="sd"> for jj in range(input.rank() + len(dim)):</span>
|
||
<span class="sd"> new_shape.append(1 if jj in dims else input.shape[ii++])</span>
|
||
|
||
<span class="sd"> For example, for a tensor of shape [3, 4, 1, 5]</span>
|
||
|
||
<span class="sd"> expand_dims(input, [0, 2])</span>
|
||
|
||
<span class="sd"> will produce a tensor of shape [1, 3, 1, 4, 1, 5].</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to expand.</span>
|
||
|
||
<span class="sd"> dim : Union[int, Sequence[int]]</span>
|
||
<span class="sd"> The positions in the output tensor where to insert singleton</span>
|
||
<span class="sd"> dimensions.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the shuffle layer.</span>
|
||
<span class="sd"> '''</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">dim</span> <span class="o">=</span> <span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="p">)</span>
<span class="n">out_ndim</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span> <span class="o">+</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="n">input_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">cast_to_dtype</span><span class="o">=</span><span class="n">shape_cast_dtype</span><span class="p">)</span>
<span class="n">out_shapes</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">j</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">out_ndim</span><span class="p">):</span>
<span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dim</span><span class="p">:</span>
<span class="n">out_shapes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">out_shapes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">gather</span><span class="p">(</span><span class="n">input_shape</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">j</span><span class="p">))</span>
<span class="n">j</span> <span class="o">=</span> <span class="n">j</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">out_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">out_shapes</span><span class="p">)</span>
<span class="k">return</span> <span class="n">view</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">out_shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></div>


<span class="c1"># NOTE: Jointly added with Apple</span>
<div class="viewcode-block" id="squeeze">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.squeeze">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">squeeze</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">zero_is_placeholder</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to remove singleton dimensions of a tensor.</span>
|
||
|
||
<span class="sd"> This functions creates an operation that removes singleton dimension</span>
|
||
<span class="sd"> (dimension of size 1) at positions 'dim' in the input tensor. It works with</span>
|
||
<span class="sd"> negative values for the 'dim'.</span>
|
||
|
||
<span class="sd"> For example, for a tensor 'input' of shape [1, 4, 1, 4]:</span>
|
||
|
||
<span class="sd"> squeeze(input, 0) will produce an output of shape [4, 1, 4],</span>
|
||
<span class="sd"> squeeze(input, 2) will produce an output of shape [1, 4, 4],</span>
|
||
<span class="sd"> squeeze(input, [0, 2]) will produce an output of shape [4, 4],</span>
|
||
<span class="sd"> squeeze(input, [-2]) will produce an output of shape [1, 4, 4],</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor for which the singleton dimensions will be removed.</span>
|
||
|
||
<span class="sd"> dim : Union[int, Sequence[int]]</span>
|
||
<span class="sd"> The index of the singleton dimensions in the input tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the layer.</span>
|
||
<span class="sd"> '''</span>
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">dim</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()))</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">dim</span> <span class="o">=</span> <span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="p">)</span>
<span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">shape</span><span class="p">):</span>
<span class="k">if</span> <span class="n">s</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dim</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span> <span class="k">else</span> <span class="p">[]</span>
<span class="nb">input</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="n">zero_is_placeholder</span><span class="p">)</span>
<span class="k">return</span> <span class="nb">input</span></div>


<div class="viewcode-block" id="unsqueeze">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.unsqueeze">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">unsqueeze</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">axis</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to insert a singleton dimension to a tensor.</span>
|
||
|
||
<span class="sd"> That functions creates an operation that insert a singleton dimension</span>
|
||
<span class="sd"> (dimension of size 1) at position 'axis' in the output tensor. It works with</span>
|
||
<span class="sd"> negative values for the 'axis'.</span>
|
||
|
||
<span class="sd"> For example, for a tensor 'input' of shape [4, 4]:</span>
|
||
|
||
<span class="sd"> unsqueeze(input, 0) will produce an output of shape [1, 4, 4],</span>
|
||
<span class="sd"> unsqueeze(input, 1) will produce an output of shape [4, 1, 4],</span>
|
||
<span class="sd"> unsqueeze(input, -1) will produce an output of shape [4, 4, 1],</span>
|
||
<span class="sd"> unsqueeze(input, -2) will produce an output of shape [4, 1, 4],</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor to expand with a singleton dimension.</span>
|
||
|
||
<span class="sd"> axis : int</span>
|
||
<span class="sd"> The index of the singleton dimension in the output tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the layer.</span>
|
||
<span class="sd"> '''</span>
<span class="k">if</span> <span class="n">axis</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">axis</span> <span class="o">=</span> <span class="n">axis</span> <span class="o">+</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span>
<span class="k">return</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">axis</span><span class="p">)</span></div>


<div class="viewcode-block" id="stack">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.stack">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">stack</span><span class="p">(</span><span class="n">inputs</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to contact input tensors along a new dimension.</span>
|
||
|
||
<span class="sd"> The function creates an operation that creates a new dim for all the</span>
|
||
<span class="sd"> input tensors and then concatenates them along that new dim.</span>
|
||
<span class="sd">.</span>
|
||
|
||
<span class="sd"> All the tensors in 'inputs' must have the same shape.</span>
|
||
|
||
<span class="sd"> for ii in range(inputs[0].rank()):</span>
|
||
<span class="sd"> assert all(inp.shape[ii] == inputs[0].shape[ii] for inp in inputs)</span>
|
||
|
||
<span class="sd"> The shape of the output tensor is defined as:</span>
|
||
|
||
<span class="sd"> output.rank() = inputs[0].rank() + 1</span>
|
||
|
||
<span class="sd"> output.shape[dim] = len(inputs)</span>
|
||
|
||
<span class="sd"> for ii in range(inputs[0].rank()):</span>
|
||
<span class="sd"> if ii < dim:</span>
|
||
<span class="sd"> output.shape[ii] = inputs[0].shape[ii]</span>
|
||
<span class="sd"> else:</span>
|
||
<span class="sd"> output.shape[ii+1] = inputs[0].shape[ii]</span>
|
||
|
||
<span class="sd"> For example, given a sequence of two 2D tensors [[0, 1], [2, 3]] and</span>
|
||
<span class="sd"> [[4, 5], [6, 7]] both of shape [2, 2],</span>
|
||
|
||
<span class="sd"> stack(inputs, 0)</span>
|
||
|
||
<span class="sd"> will produce [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] of shape [2, 2, 2] and</span>
|
||
|
||
<span class="sd"> stack(inputs, 1)</span>
|
||
|
||
<span class="sd"> will produce [[[0, 1], [4, 5]], [[2, 3], [6, 7]]] of shape [2, 2, 2].</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> inputs : Sequence[Tensor]</span>
|
||
<span class="sd"> The sequence of tensors to stack.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension in which the stack is performed.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tensor that contains the input tensors stacked along a new dimension.</span>
|
||
<span class="sd"> '''</span>
<span class="k">return</span> <span class="n">concat</span><span class="p">([</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">inp</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="n">dim</span><span class="p">)</span> <span class="k">for</span> <span class="n">inp</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">)</span></div>


<div class="viewcode-block" id="expand_dims_like">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand_dims_like">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">expand_dims_like</span><span class="p">(</span><span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span> <span class="n">right</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to expand the first tensor to the same rank as the second</span>
|
||
<span class="sd"> tensor.</span>
|
||
|
||
<span class="sd"> That function takes a first tensor. It also accepts an integer or a float,</span>
|
||
<span class="sd"> in which case it creates a constant tensor from it. In both cases, the rank</span>
|
||
<span class="sd"> of that first tensor is compared to the rank of the second tensor. If they</span>
|
||
<span class="sd"> are of the same rank, the first tensor is returned. Otherwise, the first</span>
|
||
<span class="sd"> tensor is expanded on the left to match the rank of the second tensor.</span>
|
||
|
||
<span class="sd"> Note that the shapes do not have to match, only the rank is considered in</span>
|
||
<span class="sd"> that function.</span>
|
||
|
||
<span class="sd"> For example, for a pair of tensors of shapes [3, 4] and [4, 3, 2], the</span>
|
||
<span class="sd"> first tensor will be expanded to a tensor of rank 3 and shape [1, 3, 4].</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> left : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The first tensor to expand. When a scalar value is provided as a</span>
|
||
<span class="sd"> parameter, that function first creates a tensor before expanding it</span>
|
||
<span class="sd"> (if needed).</span>
|
||
|
||
<span class="sd"> right : Tensor</span>
|
||
<span class="sd"> The reference tensor to match.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the shuffle layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="n">left</span><span class="p">]))</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="nb">float</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="ow">and</span> <span class="n">right</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">HALF</span><span class="p">:</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp16_array</span><span class="p">([</span><span class="n">left</span><span class="p">]))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">([</span><span class="n">left</span><span class="p">]))</span>
|
||
<span class="n">left_ndim</span> <span class="o">=</span> <span class="n">left</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="n">right_ndim</span> <span class="o">=</span> <span class="n">right</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">right_ndim</span> <span class="o">></span> <span class="n">left_ndim</span><span class="p">:</span>
|
||
<span class="n">new_ndim</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">right_ndim</span> <span class="o">-</span> <span class="n">left_ndim</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">new_ndim</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">left</span></div>

<span class="c1"># If dim is None, return a 1-D TensorRT-LLM tensor that contains the input's shape.</span>
<span class="c1"># If dim is not None, return a 0-D TensorRT-LLM tensor that holds the size of that dimension.</span>
<div class="viewcode-block" id="shape">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.shape">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
          <span class="n">dim</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
          <span class="n">cast_to_dtype</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
          <span class="n">clip_before_cast</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add an operation to create a shape tensor.</span>

<span class="sd">    The shape tensor can either be the shape of the input tensor when the</span>
<span class="sd">    parameter dim is None or a scalar (tensor of rank 0) that corresponds to</span>
<span class="sd">    the size of the dim-th dimension.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The input tensor from which we want to extract the shape or the</span>
<span class="sd">            size in one dimension.</span>

<span class="sd">        dim : Optional[int]</span>
<span class="sd">            The dimension from which to extract the size. If it is None, the</span>
<span class="sd">            entire shape of the input tensor is returned.</span>

<span class="sd">        cast_to_dtype : Optional[Union[str, trt.DataType]]</span>
<span class="sd">            If set, the shape tensor is cast to that type before being returned.</span>

<span class="sd">        clip_before_cast : Sequence[int]</span>
<span class="sd">            An optional (lower, upper) pair used to clip the shape values before</span>
<span class="sd">            casting them to 'int32'.</span>

<span class="sd">    Returns:</span>
<span class="sd">        A tensor that contains the shape of the input tensor (if 'dim' is None)</span>
<span class="sd">        or the size in the dimension 'dim' of the input tensor. If 'dim' is</span>
<span class="sd">        'None', that tensor is 1-D with as many elements as the input tensor</span>
<span class="sd">        has dimensions; otherwise, its rank is 0.</span>
<span class="sd">    '''</span>
    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shape</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
    <span class="n">res</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">cast_to_dtype</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
        <span class="k">if</span> <span class="n">clip_before_cast</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="p">(</span><span class="n">cast_to_dtype</span> <span class="o">==</span> <span class="s1">'int32'</span>
                                             <span class="ow">or</span> <span class="n">cast_to_dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">):</span>
            <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span>
                <span class="n">clip_before_cast</span>
            <span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"This parameter only expects a tuple of 2 integers (lower, upper) but got </span><span class="si">{</span><span class="n">clip_before_cast</span><span class="si">}</span><span class="s2">"</span>
            <span class="n">res</span> <span class="o">=</span> <span class="n">int_clip</span><span class="p">(</span><span class="n">res</span><span class="p">,</span> <span class="n">clip_before_cast</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">clip_before_cast</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
        <span class="n">res</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">res</span><span class="p">,</span> <span class="n">cast_to_dtype</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">res</span>

    <span class="k">return</span> <span class="n">gather</span><span class="p">(</span><span class="n">res</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">indices</span><span class="o">=</span><span class="n">dim</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">([])</span></div>

<div class="viewcode-block" id="gather">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gather">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">gather</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">indices</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add an operation to gather elements from a tensor.</span>

<span class="sd">    That function implements the GatherElements operator from the ONNX</span>
<span class="sd">    specification as described in</span>

<span class="sd">        https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements</span>

<span class="sd">    The input and indices arguments must have the same rank >= 1. The operation</span>
<span class="sd">    will produce a tensor with the same shape as the indices tensor. The axis</span>
<span class="sd">    is the dimension to gather on.</span>

<span class="sd">    As shown in the ONNX description, for a 3D tensor, the output is:</span>

<span class="sd">        out[i][j][k] = input[indices[i][j][k]][j][k] if axis = 0,</span>
<span class="sd">        out[i][j][k] = input[i][indices[i][j][k]][k] if axis = 1,</span>
<span class="sd">        out[i][j][k] = input[i][j][indices[i][j][k]] if axis = 2.</span>

<span class="sd">    For example,</span>

<span class="sd">        gather([[4, 2], [5, 3]], 0, [[1, 0], [0, 1]])</span>

<span class="sd">    will produce [[5, 2], [4, 3]].</span>

<span class="sd">        gather([[1, 2, 3], [4, 5, 6]], 1, [[1], [0]])</span>

<span class="sd">    will produce [[2], [4]]. See the ONNX documentation for more examples.</span>

<span class="sd">    That operation maps to the TensorRT IGatherLayer.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The input tensor to gather elements from.</span>

<span class="sd">        dim : int</span>
<span class="sd">            The dimension to gather on.</span>

<span class="sd">        indices : Union[Tensor, int]</span>
<span class="sd">            The positions in the 'dim' dimension to gather from.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor containing the gathered elements. It has the same shape as</span>
<span class="sd">        the indices tensor.</span>
<span class="sd">    '''</span>
    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">indices</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
        <span class="n">indices</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">indices</span><span class="p">]))</span>

    <span class="c1"># The input and indices tensors must have the same rank.</span>
    <span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="n">indices</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span>

    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather_v2</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
                                           <span class="n">indices</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
                                           <span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">GatherMode</span><span class="o">.</span><span class="n">ELEMENT</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
        <span class="n">dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="n">dim</span>
    <span class="n">layer</span><span class="o">.</span><span class="n">axis</span> <span class="o">=</span> <span class="n">dim</span>
    <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>

<div class="viewcode-block" id="select">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.select">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">select</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add an operation to select a slice of elements from a tensor.</span>

<span class="sd">    Given an input tensor, that function creates an operation that selects the</span>
<span class="sd">    index-th slice of elements in the dimension 'dim' to create a new tensor.</span>
<span class="sd">    The output tensor has a shape in which the input dimension 'dim' is</span>
<span class="sd">    removed.</span>

<span class="sd">    The 'index' can either be an integer or a 1D tensor containing a single</span>
<span class="sd">    element.</span>

<span class="sd">    For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape</span>
<span class="sd">    [3, 3],</span>

<span class="sd">        select(input, 0, 1)</span>

<span class="sd">    will create a tensor of shape [3] that contains [2, 1, 2].</span>

<span class="sd">    Regarding the shape of the output tensor, the dimension 'dim' is removed.</span>
<span class="sd">    It means that for a tensor of shape [4, 2, 6, 3],</span>

<span class="sd">        select(input, 2, 4)</span>

<span class="sd">    will select the 5th slice (index == 4) from the 3rd dimension (dim == 2)</span>
<span class="sd">    and return a tensor of shape [4, 2, 3] (i.e. the 3rd dimension is removed).</span>

<span class="sd">    That operation maps to the TensorRT IGatherLayer.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The input tensor to select from.</span>

<span class="sd">        dim : int</span>
<span class="sd">            The dimension to select from.</span>

<span class="sd">        index : Union[Tensor, int]</span>
<span class="sd">            The index of the slice in the 'dim' dimension to select.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor containing the selected slice.</span>
<span class="sd">    '''</span>
    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">index</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
        <span class="n">index</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">index</span><span class="p">]))</span>
    <span class="k">assert</span> <span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">index</span><span class="o">.</span><span class="n">size</span><span class="p">(</span>
        <span class="mi">0</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"index should have rank 1, got </span><span class="si">{</span><span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span>

    <span class="n">new_shape</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()):</span>
        <span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="n">dim</span><span class="p">:</span>
            <span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>

    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">index</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">new_shape</span><span class="p">))</span></div>

<div class="viewcode-block" id="index_select">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.index_select">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">index_select</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add an operation to select slices of elements from a tensor.</span>

<span class="sd">    Given an input tensor, that function creates an operation that selects the</span>
<span class="sd">    slices of elements in the dimension 'dim' at the indices listed in 'index'</span>
<span class="sd">    to create a new tensor. The output tensor has the same rank as the input</span>
<span class="sd">    tensor.</span>

<span class="sd">    The 'index' is a tensor of rank 1.</span>

<span class="sd">    For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape</span>
<span class="sd">    [3, 3],</span>

<span class="sd">        index_select(input, 0, [0, 1])</span>

<span class="sd">    will create a tensor of shape [2, 3] that contains [[4, 2, 5], [2, 1, 2]].</span>

<span class="sd">    Regarding the shape of the output tensor, the dimension 'dim' has the same</span>
<span class="sd">    size as the 'index' tensor. It means that for an input tensor of shape [4, 2, 6, 3],</span>

<span class="sd">        index_select(input, 2, [1, 4])</span>

<span class="sd">    will select the 2nd and 5th slices (index == 1 or 4) from the 3rd dimension</span>
<span class="sd">    (dim == 2) and return a tensor of shape [4, 2, 2, 3] (i.e. the 3rd</span>
<span class="sd">    dimension is shrunk to 2).</span>

<span class="sd">    Note that this operation can also be used to expand a tensor in the 'dim'</span>
<span class="sd">    dimension, for example, on input [[0, 1], [2, 3]],</span>

<span class="sd">        index_select(input, 1, [0, 0, 0])</span>

<span class="sd">    will produce a tensor of shape [2, 3] containing [[0, 0, 0], [2, 2, 2]].</span>

<span class="sd">    That operation maps to the TensorRT IGatherLayer.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The input tensor to select from.</span>

<span class="sd">        dim : int</span>
<span class="sd">            The dimension to select from.</span>

<span class="sd">        index : Tensor</span>
<span class="sd">            The indices of the slices in the 'dim' dimension to select.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor containing the selected slices.</span>
<span class="sd">    '''</span>
    <span class="k">assert</span> <span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"index should have rank 1, got </span><span class="si">{</span><span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span>

    <span class="n">new_shape</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()):</span>
        <span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="n">dim</span><span class="p">:</span>
            <span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="n">index</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>

    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">index</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">new_shape</span><span class="p">))</span></div>

<span class="c1"># NOTE: Jointly added with Apple</span>
<div class="viewcode-block" id="scatter">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.scatter">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">scatter</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">indices</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
            <span class="n">updates</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    This operation adds a layer that creates an output tensor by element-wise</span>
<span class="sd">    copying values from the input tensor and then updating values by the given</span>
<span class="sd">    `indices` and `updates` tensors.</span>

<span class="sd">    For a 2D input tensor, it first copies the input to output,</span>
<span class="sd">    then updates the output tensor like the following for each entry in `updates`:</span>

<span class="sd">        output[indices[i][j]][j] = updates[i][j] if dim=0</span>
<span class="sd">        output[i][indices[i][j]] = updates[i][j] if dim=1</span>

<span class="sd">    If the `input` tensor is [[1, 2, 3], [4, 5, 6]],</span>
<span class="sd">    the indices tensor is [[1, 2], [0, 1]],</span>
<span class="sd">    the updates tensor is [[-1, -2], [-3, -4]], and dim=1,</span>
<span class="sd">    the output tensor will be [[1, -1, -2], [-3, -4, 6]].</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input: Tensor</span>
<span class="sd">            The input data that needs to be updated.</span>

<span class="sd">        dim: int</span>
<span class="sd">            The axis on which the scatter is to be performed.</span>

<span class="sd">        indices: Tensor</span>
<span class="sd">            An integer tensor of the same rank as input that indicates the positions to be updated.</span>

<span class="sd">        updates: Tensor</span>
<span class="sd">            A data tensor of the same shape as the `indices` tensor that contains the update values.</span>

<span class="sd">    Returns:</span>
<span class="sd">        A tensor created by the element-wise scatter layer.</span>
<span class="sd">    '''</span>
    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_scatter</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
                                         <span class="n">indices</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
                                         <span class="n">updates</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
                                         <span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ScatterMode</span><span class="o">.</span><span class="n">ELEMENT</span><span class="p">)</span>
    <span class="n">layer</span><span class="o">.</span><span class="n">axis</span> <span class="o">=</span> <span class="n">dim</span>
    <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>

<div class="viewcode-block" id="gather_nd">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gather_nd">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">gather_nd</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">indices</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">batch_dims</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Adds a layer that performs a gather with some element-wise dimensions.</span>
<span class="sd">    See: https://onnx.ai/onnx/operators/onnx__GatherND.html</span>
<span class="sd">    The gather is performed on dim=batch_dims.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input: Tensor</span>
<span class="sd">            The tensor on which the gather operation is performed.</span>

<span class="sd">        indices: Tensor</span>
<span class="sd">            The tensor that indicates which entries are to be gathered.</span>

<span class="sd">        batch_dims: int</span>
<span class="sd">            The number of leading (batch) dimensions that are skipped before the gather starts.</span>

<span class="sd">    Returns:</span>
<span class="sd">        A tensor created by the gather layer with GatherMode.ND.</span>
<span class="sd">    '''</span>
    <span class="n">gather_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather_v2</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
                                                  <span class="n">indices</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
                                                  <span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">GatherMode</span><span class="o">.</span><span class="n">ND</span><span class="p">)</span>
    <span class="n">gather_layer</span><span class="o">.</span><span class="n">num_elementwise_dims</span> <span class="o">=</span> <span class="n">batch_dims</span>
    <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">gather_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">gather_layer</span><span class="p">)</span></div>

<div class="viewcode-block" id="nonzero">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.nonzero">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">nonzero</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Adds a layer that finds the indices of non-zero values of the input tensor.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input: Tensor</span>
<span class="sd">            The input tensor for which we need to find the indices of non-zero values.</span>

<span class="sd">    Returns:</span>
<span class="sd">        A tensor of shape [D, C] where D is the number of dimensions of `input` and</span>
<span class="sd">        C is the number of non-zero values in it.</span>
<span class="sd">        Each column of this 2D tensor represents the index tuple for each non-zero value.</span>
<span class="sd">    '''</span>
    <span class="n">non_zero_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_non_zero</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">non_zero_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">non_zero_layer</span><span class="p">)</span></div>

<div class="viewcode-block" id="masked_select">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.masked_select">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">masked_select</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add an operation to select elements from a tensor according to a boolean</span>
<span class="sd">    mask tensor.</span>

<span class="sd">    Given an input tensor, that function creates an operation that selects</span>
<span class="sd">    elements at the indices indicated by the boolean mask tensor to create</span>
<span class="sd">    a new tensor. The output tensor is a 1-D tensor.</span>

<span class="sd">    The input tensor must have rank >= 1. The shapes of the input tensor and</span>
<span class="sd">    the mask tensor do not need to match, but they must be broadcastable.</span>

<span class="sd">    For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape</span>
<span class="sd">    [3, 3],</span>

<span class="sd">        masked_select(input, [[True, False, True], [False, True, False], [True, False, True]])</span>

<span class="sd">    will create a tensor of shape [5] that contains [4, 5, 1, 4, 1].</span>

<span class="sd">        masked_select(input, [[True], [False], [True]])</span>

<span class="sd">    will create a tensor of shape [6] that contains [4, 2, 5, 4, 7, 1].</span>

<span class="sd">        masked_select(input, [[False, False, True]])</span>

<span class="sd">    will create a tensor of shape [3] that contains [5, 2, 1].</span>

<span class="sd">        masked_select(input, [False])</span>

<span class="sd">    will create a tensor of shape [0] which is empty.</span>

<span class="sd">    That operation is implemented by NonZero, Shuffle and GatherV2 layers</span>
<span class="sd">    in TensorRT.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The input tensor to select from.</span>

<span class="sd">        mask : Tensor</span>
<span class="sd">            The boolean mask tensor that indicates elements to select.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The 1-D tensor containing the selected elements.</span>
<span class="sd">    '''</span>
    <span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">>=</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"input should have rank >= 1"</span>
    <span class="nb">input</span><span class="p">,</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span>
    <span class="n">expanded_mask</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">))</span>

    <span class="n">non_zero_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_non_zero</span><span class="p">(</span><span class="n">expanded_mask</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>

    <span class="n">shuffle_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shuffle</span><span class="p">(</span><span class="n">non_zero_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
    <span class="n">shuffle_layer</span><span class="o">.</span><span class="n">second_transpose</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>

    <span class="n">gather_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather_v2</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
                                                  <span class="n">shuffle_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
                                                  <span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">GatherMode</span><span class="o">.</span><span class="n">ND</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">gather_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">gather_layer</span><span class="p">)</span></div>

<div class="viewcode-block" id="cumsum">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.cumsum">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">cumsum</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">prefer_plugin</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add an operation to calculate inclusive cumulative sum of elements of</span>
<span class="sd">    a tensor in a given dimension.</span>

<span class="sd">    Given an input tensor, that function creates an operation that calculates</span>
<span class="sd">    inclusive cumulative sum of elements in the dimension 'dim' to create</span>
<span class="sd">    a new tensor. The output tensor has the same shape as the input tensor.</span>

<span class="sd">    The input tensor must have rank >= 1. The 'dim' must be valid; negative</span>
<span class="sd">    values are supported.</span>

<span class="sd">    For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape</span>
<span class="sd">    [3, 3],</span>

<span class="sd">        cumsum(input, 0)</span>

<span class="sd">    will produce [[4, 2, 5], [6, 3, 7], [10, 10, 8]].</span>

<span class="sd">        cumsum(input, 1)</span>

<span class="sd">    will produce [[4, 6, 11], [2, 3, 5], [4, 11, 12]].</span>

<span class="sd">    That operation is implemented by the cumsumLastDim plugin or the TensorRT</span>
<span class="sd">    ILoopLayer, depending on 'dim' and 'prefer_plugin'.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The input tensor to calculate the inclusive cumulative sum.</span>

<span class="sd">        dim : int</span>
<span class="sd">            The dimension to calculate the inclusive cumulative sum. Negative</span>
<span class="sd">            values are supported.</span>

<span class="sd">        prefer_plugin : bool</span>
<span class="sd">            Whether to use the cumsumLastDim plugin if 'dim' is the last dimension.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor containing the inclusive cumulative sum of input.</span>
<span class="sd">    '''</span>
    <span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">>=</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"input should have rank >= 1"</span>
    <span class="k">assert</span> <span class="n">dim</span> <span class="o"><</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="ow">and</span> <span class="n">dim</span> <span class="o">>=</span> <span class="o">-</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">(</span>
    <span class="p">),</span> <span class="sa">f</span><span class="s2">"dim should be in [</span><span class="si">{</span><span class="o">-</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">, </span><span class="si">{</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">) when input has rank </span><span class="si">{</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span>

    <span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())[</span><span class="mi">0</span><span class="p">]</span>

    <span class="k">if</span> <span class="n">dim</span> <span class="o">==</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
        <span class="k">if</span> <span class="n">prefer_plugin</span><span class="p">:</span>
            <span class="n">last_dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
            <span class="k">if</span> <span class="n">last_dim</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span> <span class="c1"># dynamic?</span>
                <span class="n">last_dim</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
            <span class="n">old_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
            <span class="k">if</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
                <span class="n">input_2d</span> <span class="o">=</span> <span class="n">unsqueeze</span><span class="p">(</span>
                    <span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="c1"># special handling of rank-1 dynamic tensor</span>
            <span class="k">elif</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">!=</span> <span class="mi">2</span><span class="p">:</span>
                <span class="n">input_2d</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">last_dim</span><span class="p">]),</span>
                                      <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">input_2d</span> <span class="o">=</span> <span class="nb">input</span>
            <span class="n">cumsum_last_dim_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">(</span>
            <span class="p">)</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span><span class="s1">'CumsumLastDim'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
            <span class="k">assert</span> <span class="n">cumsum_last_dim_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
            <span class="n">input_length</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
                <span class="s2">"input_length"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">input_2d</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
            <span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"type_id"</span><span class="p">,</span>
                                      <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">input_2d</span><span class="o">.</span><span class="n">dtype</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                                      <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
            <span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">input_length</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
            <span class="n">cumsum_last_dim_plug</span> <span class="o">=</span> <span class="n">cumsum_last_dim_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span>
                <span class="s2">"cumsum_last_dim"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
            <span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">input_2d</span><span class="p">]</span>
            <span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
            <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span>
                                                   <span class="n">cumsum_last_dim_plug</span><span class="p">)</span>
            <span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">cumsum_last_dim_plg_creator</span><span class="p">,</span>
                             <span class="s2">"cumsum_last_dim"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
            <span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
            <span class="n">output</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">old_shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
            <span class="k">return</span> <span class="n">output</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="c1"># credit to Apple</span>
            <span class="n">reduction_length</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
            <span class="n">reduction_range</span> <span class="o">=</span> <span class="n">arange</span><span class="p">(</span><span class="n">constant_to_tensor_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
                                                         <span class="n">dtype</span><span class="o">=</span><span class="s1">'int64'</span><span class="p">,</span>
                                                         <span class="n">to_array</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span>
                                     <span class="n">reduction_length</span><span class="p">,</span>
                                     <span class="n">dtype</span><span class="o">=</span><span class="s1">'int64'</span><span class="p">)</span>
            <span class="n">lower_triangle</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">reduction_range</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
                                  <span class="o"><=</span> <span class="n">unsqueeze</span><span class="p">(</span><span class="n">reduction_range</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
                                  <span class="n">dtype</span><span class="o">=</span><span class="nb">input</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
            <span class="n">output</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">unsqueeze</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)</span> <span class="o">*</span> <span class="n">lower_triangle</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
            <span class="k">return</span> <span class="n">output</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">slice_shape</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()):</span>
            <span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="n">dim</span><span class="p">:</span>
                <span class="n">slice_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>

        <span class="n">zero_tensor</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">slice_shape</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
            <span class="n">zero_tensor</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">zero_tensor</span><span class="p">,</span>
                                      <span class="p">[</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">slice_shape</span><span class="p">))])</span>
            <span class="n">slice_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">slice_shape</span><span class="p">)</span>
            <span class="n">zero_tensor</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">zero_tensor</span><span class="p">,</span> <span class="n">slice_shape</span><span class="p">)</span>

        <span class="n">loop_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_loop</span><span class="p">()</span>
        <span class="n">trip_limit</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span>
        <span class="n">loop_layer</span><span class="o">.</span><span class="n">add_trip_limit</span><span class="p">(</span><span class="n">trip_limit</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">TripLimit</span><span class="o">.</span><span class="n">COUNT</span><span class="p">)</span>

        <span class="n">iterator_layer</span> <span class="o">=</span> <span class="n">loop_layer</span><span class="o">.</span><span class="n">add_iterator</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
        <span class="n">cur_slice</span> <span class="o">=</span> <span class="n">iterator_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>

        <span class="n">running_sum_layer</span> <span class="o">=</span> <span class="n">loop_layer</span><span class="o">.</span><span class="n">add_recurrence</span><span class="p">(</span><span class="n">zero_tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
        <span class="n">running_sum</span> <span class="o">=</span> <span class="n">running_sum_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>

        <span class="n">cur_sum_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_elementwise</span><span class="p">(</span>
            <span class="n">cur_slice</span><span class="p">,</span> <span class="n">running_sum</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">SUM</span><span class="p">)</span>
        <span class="n">cur_sum</span> <span class="o">=</span> <span class="n">cur_sum_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
        <span class="n">running_sum_layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">cur_sum</span><span class="p">)</span>

        <span class="n">loop_output_layer</span> <span class="o">=</span> <span class="n">loop_layer</span><span class="o">.</span><span class="n">add_loop_output</span><span class="p">(</span>
            <span class="n">cur_sum</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">LoopOutput</span><span class="o">.</span><span class="n">CONCATENATE</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
        <span class="n">loop_output_layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">trip_limit</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">loop_output_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
                              <span class="n">loop_output_layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
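# --- Illustrative sketch (not part of tensorrt_llm) --------------------------
# The TensorRT loop built above is a running-sum recurrence. The NumPy model
# below (the helper name `_cumsum_by_recurrence` is hypothetical, chosen for
# illustration only) computes the same thing the loop layers express.
import numpy as np

def _cumsum_by_recurrence(x, dim):
    # `running` plays the role of the `zero_tensor` recurrence initializer.
    running = np.zeros(tuple(s for i, s in enumerate(x.shape) if i != dim),
                       dtype=x.dtype)
    outputs = []
    for step in range(x.shape[dim]):      # trip count == shape(input, dim)
        running = running + np.take(x, step, axis=dim)  # iterator slice + SUM
        outputs.append(running)           # LoopOutput.CONCATENATE along `dim`
    return np.stack(outputs, axis=dim)

# Sanity check against the reference semantics of cumsum:
_x = np.arange(6, dtype=np.float32).reshape(2, 3)
assert np.array_equal(_cumsum_by_recurrence(_x, 1), np.cumsum(_x, axis=1))
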
<div class="viewcode-block" id="masked_scatter">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.masked_scatter">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">masked_scatter</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">source</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add the masked_scatter base on PyTorch definition.</span>
|
||
|
||
<span class="sd"> See https://pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter_.html#torch-tensor-masked-scatter for a</span>
|
||
<span class="sd"> description of that function.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> mask : Tensor</span>
|
||
<span class="sd"> The boolean mask tensor that indicates elements to select.</span>
|
||
|
||
<span class="sd"> source: Tensor</span>
|
||
<span class="sd"> The tensor to copy from</span>
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor containing the source tensor selected by mask.</span>
|
||
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">>=</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"input should have rank >= 1"</span>
|
||
<span class="nb">input</span><span class="p">,</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span>
|
||
<span class="n">expanded_mask</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">))</span>
|
||
|
||
<span class="n">non_zero_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_non_zero</span><span class="p">(</span><span class="n">expanded_mask</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="n">shuffle_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shuffle</span><span class="p">(</span><span class="n">non_zero_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
|
||
<span class="n">shuffle_layer</span><span class="o">.</span><span class="n">second_transpose</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">source</span> <span class="o">=</span> <span class="n">source</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
||
|
||
<span class="n">scatter_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_scatter</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">shuffle_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">source</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ScatterMode</span><span class="o">.</span><span class="n">ND</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">scatter_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">scatter_layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
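# --- Illustrative sketch (not part of tensorrt_llm) --------------------------
# What the layers above compute, modeled in NumPy: the coordinates of the True
# entries of `mask` (add_non_zero + transpose) receive consecutive values
# taken from a flattened `source` (add_scatter with ScatterMode.ND).
import numpy as np

def _masked_scatter_ref(inp, mask, src):
    out = inp.copy()
    coords = np.argwhere(np.broadcast_to(mask, inp.shape))  # non-zero coords
    flat_src = src.reshape(-1)                              # source.view([-1])
    for n, idx in enumerate(coords):                        # ND-mode scatter
        out[tuple(idx)] = flat_src[n]
    return out

_inp = np.zeros((2, 3), dtype=np.float32)
_mask = np.array([[True, False, True], [False, True, False]])
_src = np.array([1.0, 2.0, 3.0], dtype=np.float32)
# Positions with mask=True receive 1, 2, 3 in row-major order.
assert _masked_scatter_ref(_inp, _mask, _src).tolist() == [[1.0, 0.0, 2.0],
                                                           [0.0, 3.0, 0.0]]
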
<div class="viewcode-block" id="concat">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.concat">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">concat</span><span class="p">(</span><span class="n">inputs</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">]],</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to concatenate tensors.</span>
|
||
|
||
<span class="sd"> The function creates an operation that concatenates the tensors from the</span>
|
||
<span class="sd"> sequence 'inputs'. The concatenation is done along the dimension 'dim'.</span>
|
||
|
||
<span class="sd"> All the tensors in 'inputs' must have the same shape expect for the</span>
|
||
<span class="sd"> dimension 'dim'.</span>
|
||
|
||
<span class="sd"> for ii in range(inputs[0].rank()):</span>
|
||
<span class="sd"> assert (ii == dim) or all(inp.shape[ii] == inputs[0].shape[ii] for inp in inputs)</span>
|
||
|
||
<span class="sd"> The shape of the output tensor is defined as:</span>
|
||
|
||
<span class="sd"> for ii in range(inputs[0].rank()):</span>
|
||
<span class="sd"> # Same size as all the inputs in dimension ii != dim.</span>
|
||
<span class="sd"> output.shape[ii] = inputs[0].shape[ii]</span>
|
||
|
||
<span class="sd"> # Sum of the sizes in the different inputs in dimension 'dim'.</span>
|
||
<span class="sd"> if ii == dim:</span>
|
||
<span class="sd"> for jj in range(1, len(inputs)):</span>
|
||
<span class="sd"> output.shape[ii] += inputs[jj].shape[ii]</span>
|
||
|
||
<span class="sd"> For example, given a sequence of two 2D tensors [[0, 1], [2, 3]] and</span>
|
||
<span class="sd"> [[4, 5], [6, 7]] both of shape [2, 2],</span>
|
||
|
||
<span class="sd"> concat(inputs, 0)</span>
|
||
|
||
<span class="sd"> will produce [[0, 1], [2, 3], [4, 5], [6, 7]] of shape [4, 2] and</span>
|
||
|
||
<span class="sd"> concat(inputs, 1)</span>
|
||
|
||
<span class="sd"> will produce [[0, 1, 4, 5], [2, 3, 6, 7]] of shape [2, 4].</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> inputs : Sequence[Union[Tensor, int]]</span>
|
||
<span class="sd"> The sequence of tensors to concatenate. For integers, that function</span>
|
||
<span class="sd"> creates constant tensors.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension in which the concatenation is performed.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tensor that contains the concatenation of the tensors.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span>
|
||
<span class="n">inputs</span>
|
||
<span class="p">)</span> <span class="o">></span> <span class="mi">0</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"Number of inputs (</span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span><span class="si">}</span><span class="s2">) to the concatenation layer must be > 0."</span>
|
||
<span class="n">tmp</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="n">inputs</span> <span class="o">=</span> <span class="n">constants_to_tensors_</span><span class="p">(</span><span class="o">*</span><span class="n">inputs</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">i</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">tmp</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">i</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="mi">1</span><span class="p">]))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">tmp</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_concatenation</span><span class="p">([</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">tmp</span><span class="p">])</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">axis</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">tmp</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">ndim</span><span class="p">())[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
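# --- Illustrative sketch (not part of tensorrt_llm) --------------------------
# The docstring example, checked with NumPy's equivalent operation:
import numpy as np

_a = np.array([[0, 1], [2, 3]])
_b = np.array([[4, 5], [6, 7]])
assert np.concatenate([_a, _b], axis=0).shape == (4, 2)
assert np.concatenate([_a, _b], axis=1).tolist() == [[0, 1, 4, 5],
                                                     [2, 3, 6, 7]]
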
<div class="viewcode-block" id="softmax">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.softmax">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">softmax</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to compute softmax on a tensor.</span>
|
||
|
||
<span class="sd"> That operation computes the softmax on the input tensor in the dimension</span>
|
||
<span class="sd"> 'dim' if specified. Otherwise, it is applied on the last dimension.</span>
|
||
|
||
<span class="sd"> It inserts a ISoftmaxLayer to the TensorRT graph.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor on which to apply softmax.</span>
|
||
|
||
<span class="sd"> dim : Optional[int]</span>
|
||
<span class="sd"> The dimension used to apply softmax.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The output tensor of the softmax layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="n">dim</span>
|
||
<span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_softmax</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
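# --- Illustrative sketch (not part of tensorrt_llm) --------------------------
# TensorRT encodes reduction/softmax dimensions as a bitmask over axes (the
# convention `dim_to_trt_axes` follows, assumed here to be `1 << dim`), and
# negative dims are normalized against the rank first, as in the function
# above. The helper name `_normalize_dim_to_axes` is hypothetical:
def _normalize_dim_to_axes(dim, ndim):
    if dim is None:
        dim = ndim - 1          # default: last dimension
    if dim < 0:
        dim = ndim + dim        # e.g. dim=-1, ndim=3 -> dim=2
    return 1 << dim             # bitmask with only bit `dim` set

assert _normalize_dim_to_axes(None, 3) == 0b100
assert _normalize_dim_to_axes(-2, 3) == 0b010
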
<span class="k">def</span><span class="w"> </span><span class="nf">_lookup_plugin</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">rank</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">per_token_scale</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to perform lookup in a tensor.</span>
|
||
|
||
<span class="sd"> That operation performs the lookup needed by embedding layers. Given a</span>
|
||
<span class="sd"> 'weight' tensor of shape [rows, cols], it produces a tensor of shape</span>
|
||
<span class="sd"> [inputs.size(0), cols] where the ith row corresponds to the input[i] row in</span>
|
||
<span class="sd"> the weight tensor.</span>
|
||
|
||
<span class="sd"> It inserts a IPluginV2Layer.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor contains the indices to perform the lookup.</span>
|
||
|
||
<span class="sd"> weight : Tensor</span>
|
||
<span class="sd"> The table to gather from.</span>
|
||
|
||
<span class="sd"> rank : int</span>
|
||
<span class="sd"> The mpi rank.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The output tensor of the lookup layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'Lookup'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">per_token_scale</span><span class="o">.</span><span class="n">dtype</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">rank</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">pf_type</span><span class="p">,</span> <span class="n">rank</span><span class="p">])</span>
|
||
<span class="n">lookup_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"lookup"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">per_token_scale</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">per_token_scale</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">lookup_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creator</span><span class="p">,</span> <span class="s2">"lookup"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
|
||
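# --- Illustrative sketch (not part of tensorrt_llm) --------------------------
# Semantics of the lookup the plugin implements, modeled in NumPy: gather rows
# of `weight` by index, then (optionally) scale each output row by the
# matching per-token scale. The helper name `_lookup_ref` is hypothetical.
import numpy as np

def _lookup_ref(indices, weight, per_token_scale=None):
    rows = weight[indices]                        # output[i] = weight[input[i]]
    if per_token_scale is not None:
        rows = rows * per_token_scale[..., None]  # one scale per looked-up row
    return rows

_w = np.arange(8, dtype=np.float32).reshape(4, 2)
assert _lookup_ref(np.array([2, 0]), _w).tolist() == [[4.0, 5.0], [0.0, 1.0]]
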
<div class="viewcode-block" id="embedding">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.embedding">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">embedding</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">tp_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">tp_group</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">sharding_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">tp_rank</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">per_token_scale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">padding</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to perform embedding lookup.</span>
|
||
|
||
<span class="sd"> That operation performs the embedding lookup. The 'input' tensor contains</span>
|
||
<span class="sd"> the identifiers of the rows of 'weight' to gather.</span>
|
||
|
||
<span class="sd"> 1. Distribute the embedding lookup table over multiple GPU</span>
|
||
<span class="sd"> When 'tp_size' is greater than 1 and the 'tp_group' is defined, this</span>
|
||
<span class="sd"> embedding lookup is distributed among multiple GPUs.</span>
|
||
|
||
<span class="sd"> When 'sharding_dim==0', each GPU stores a subset of the rows of the embedding</span>
|
||
<span class="sd"> table rows(that number of rows per GPU is given by weights.shape[0] and the offset to</span>
|
||
<span class="sd"> the 1st row stored on the GPU is given by rank * weights.shape[0]). Each</span>
|
||
<span class="sd"> parallel rank will query all the indices and set 0s for the weights that</span>
|
||
<span class="sd"> are not stored on the associated GPU. To compute the final result, a</span>
|
||
<span class="sd"> parallel all-reduce operation is added to the TensorRT graph. That lookup</span>
|
||
<span class="sd"> can be performed using either the plugin or the operators TensorRT support.</span>
|
||
|
||
<span class="sd"> When'sharding_dim==1', each GPU stores a subset of the embedding table's columns.</span>
|
||
<span class="sd"> Each rank can obtain a portion of the embedding results.</span>
|
||
<span class="sd"> Then the embedding is collected using the all-gather operation.</span>
|
||
<span class="sd"> Related transposition operations are also used to obtain the final results.</span>
|
||
|
||
<span class="sd"> 2. Store embedding lookup table as a whole</span>
|
||
<span class="sd"> When 'tp_size' is not greater than 1, the embedding lookup table will not</span>
|
||
<span class="sd"> be divided. In this case, when the default_net().plugin_config.lookup_plugin is set,</span>
|
||
<span class="sd"> the operation is implemented using a plugin (without the all-reduce operation).</span>
|
||
<span class="sd"> Otherwise, this operation is implemented using the standard IGatherLayer in TensorRT.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor the contains the indices to perform the lookup.</span>
|
||
|
||
<span class="sd"> weight : Tensor</span>
|
||
<span class="sd"> The table to gather from.</span>
|
||
|
||
<span class="sd"> tp_size : int</span>
|
||
<span class="sd"> The number of GPUs collaborating to perform that embedding.</span>
|
||
|
||
<span class="sd"> tg_group : Optional[List[int]]</span>
|
||
<span class="sd"> The group of world ranks participating in the all-reduce when</span>
|
||
<span class="sd"> tp_size > 1.</span>
|
||
|
||
<span class="sd"> sharding_dim : int</span>
|
||
<span class="sd"> sharding_dim = 0 means that we shard the embedding table in vocab dim;</span>
|
||
<span class="sd"> sharding_dim = 1 means that we shard the embedding table in embedding dim.</span>
|
||
|
||
<span class="sd"> tp_rank : int</span>
|
||
<span class="sd"> The tensor parallelism rank. Used to calculate offset in TP on vocab dim.</span>
|
||
|
||
<span class="sd"> padding: Tensor</span>
|
||
<span class="sd"> Additional padding added to the end of the embedding table before feeding into gather op.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by the embedding lookup layer.</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="c1"># Per token scale is only supported by lookup plugin so if per_token_scale is not None, we must use lookup plugin</span>
|
||
<span class="c1"># Otherwise, we prefer to use ootb</span>
|
||
<span class="n">use_lookup_plugin</span> <span class="o">=</span> <span class="n">per_token_scale</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="k">if</span> <span class="n">padding</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">padded_weight</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">weight</span><span class="p">,</span> <span class="n">padding</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">padded_weight</span> <span class="o">=</span> <span class="n">weight</span>
|
||
|
||
<span class="c1"># Distribute embedding lookup table across multiple GPU</span>
|
||
<span class="k">if</span> <span class="n">tp_size</span> <span class="o">></span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">tp_group</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">sharding_dim</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="c1"># TP on vocab_size dimension</span>
|
||
<span class="k">if</span> <span class="n">tp_rank</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"Rank cannot be none for tensor parallelism on vocab dim"</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">use_lookup_plugin</span><span class="p">:</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">_lookup_plugin</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">tp_rank</span><span class="p">,</span> <span class="n">per_token_scale</span><span class="p">)</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">allreduce</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">tp_group</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">shape_weight</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">weight</span><span class="p">)</span>
|
||
<span class="n">vocab_size</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">shape_weight</span><span class="p">,</span> <span class="n">starts</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">sizes</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
|
||
<span class="n">tmp_input</span> <span class="o">=</span> <span class="nb">input</span> <span class="o">-</span> <span class="n">vocab_size</span> <span class="o">*</span> <span class="n">tp_rank</span>
|
||
|
||
<span class="c1"># Identify the valid indices</span>
|
||
<span class="n">is_qualified</span> <span class="o">=</span> <span class="n">op_and</span><span class="p">(</span><span class="n">tmp_input</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">tmp_input</span> <span class="o"><</span> <span class="n">vocab_size</span><span class="p">)</span>
|
||
<span class="n">is_qualified_expand</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">is_qualified</span><span class="p">,</span>
|
||
<span class="p">[</span><span class="n">is_qualified</span><span class="o">.</span><span class="n">ndim</span><span class="p">()])</span>
|
||
|
||
<span class="c1"># Replace the invalid ones to zero</span>
|
||
<span class="n">placeholder_input</span> <span class="o">=</span> <span class="n">where</span><span class="p">(</span><span class="n">is_qualified</span><span class="p">,</span> <span class="n">tmp_input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Get the temporal results</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span>
|
||
<span class="n">padded_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">placeholder_input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">tmp_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Set zero for invalid results</span>
|
||
<span class="n">placeholder_tmp</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">is_qualified_expand</span><span class="p">,</span> <span class="n">tmp_output</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">placeholder</span> <span class="o">=</span> <span class="n">placeholder_tmp</span> <span class="o">-</span> <span class="n">placeholder_tmp</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">where</span><span class="p">(</span><span class="n">is_qualified_expand</span><span class="p">,</span> <span class="n">tmp_output</span><span class="p">,</span> <span class="n">placeholder</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Use all reduce to collect the results</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">allreduce</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">tp_group</span><span class="p">)</span>
|
||
|
||
<span class="k">elif</span> <span class="n">sharding_dim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span> <span class="c1"># TP on hidden dimension</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="n">padded_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="c1"># [dim0, local_dim] -> [dim0 * tp_size, local_dim] --> [dim0, local_dim * tp_size]</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">allgather</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">tp_group</span><span class="p">,</span> <span class="n">gather_dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s1">'Tensor Parallelism only support splitting Embedding lookup along hidden (sharding_dim==1) and vocab (sharding_dim==0) dimensionis'</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="c1"># Store embedding lookup table as a whole</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">use_lookup_plugin</span><span class="p">:</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">_lookup_plugin</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span>
|
||
<span class="n">padded_weight</span><span class="p">,</span>
|
||
<span class="n">rank</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">per_token_scale</span><span class="o">=</span><span class="n">per_token_scale</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="n">padded_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">x</span></div>
|
||
|
||
|
||
|
||
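# --- Illustrative sketch (not part of tensorrt_llm) --------------------------
# The vocab-sharded branch above, modeled in NumPy: shift the indices into the
# local shard's range, zero out rows for indices owned by other ranks, then
# sum the per-rank partial results (the role of the all-reduce). The helper
# name `_vocab_parallel_lookup` is hypothetical.
import numpy as np

def _vocab_parallel_lookup(indices, full_weight, tp_size):
    rows_per_rank = full_weight.shape[0] // tp_size
    partials = []
    for rank in range(tp_size):
        local = full_weight[rank * rows_per_rank:(rank + 1) * rows_per_rank]
        tmp = indices - rank * rows_per_rank           # shift into local range
        valid = (tmp >= 0) & (tmp < rows_per_rank)     # is_qualified
        gathered = local[np.where(valid, tmp, 0)]      # gather, placeholder 0
        partials.append(np.where(valid[:, None], gathered, 0.0))
    return sum(partials)                               # all-reduce (SUM)

_w = np.arange(12, dtype=np.float32).reshape(6, 2)
_idx = np.array([0, 3, 5])
assert np.array_equal(_vocab_parallel_lookup(_idx, _w, 2), _w[_idx])
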
<div class="viewcode-block" id="constant_to_tensor_">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.constant_to_tensor_">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">constant_to_tensor_</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="nb">bool</span><span class="p">],</span>
|
||
<span class="n">dtype</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">,</span> <span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">to_array</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="c1"># deduce the type from the given value</span>
|
||
<span class="c1"># NOTE: bool is a subtype of int, so bool needs to be checked first</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="nb">bool</span><span class="p">):</span>
|
||
<span class="n">dtype</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">bool</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">dtype</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">dtype</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">float32</span>
|
||
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
|
||
<span class="n">dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">array_fn_dict</span> <span class="o">=</span> <span class="p">{</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">:</span> <span class="n">int64_array</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span> <span class="n">int32_array</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">float32</span><span class="p">:</span> <span class="n">fp32_array</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">float16</span><span class="p">:</span> <span class="n">fp16_array</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">:</span> <span class="n">bf16_array</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">bool</span><span class="p">:</span> <span class="n">bool_array</span><span class="p">,</span>
|
||
<span class="p">}</span>
|
||
<span class="k">assert</span> <span class="n">dtype</span> <span class="ow">in</span> <span class="n">array_fn_dict</span>
|
||
<span class="k">return</span> <span class="n">constant</span><span class="p">(</span><span class="n">array_fn_dict</span><span class="p">[</span><span class="n">dtype</span><span class="p">]([</span><span class="nb">input</span><span class="p">]</span> <span class="k">if</span> <span class="n">to_array</span> <span class="k">else</span> <span class="nb">input</span><span class="p">))</span>
|
||
|
||
<span class="k">return</span> <span class="nb">input</span></div>
|
||
|
||
|
||
|
||
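# --- Illustrative note (not part of tensorrt_llm) -----------------------------
# Why bool must be checked before int above: in Python, bool is a subclass of
# int, so an isinstance(..., int) check alone would also capture True/False.
assert isinstance(True, int)   # True would match the int branch
assert isinstance(True, bool)  # so the bool branch must come first
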
<div class="viewcode-block" id="constants_to_tensors_">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.constants_to_tensors_">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">constants_to_tensors_</span><span class="p">(</span>
|
||
<span class="o">*</span><span class="n">inputs</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="o">...</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Helper function to create tensors from multiple inputs.</span>
|
||
|
||
<span class="sd"> For each inputs, that function first creates a constant tensor if the input</span>
|
||
<span class="sd"> is an integer or a float. Then, if any input is int64, it upcasts other</span>
|
||
<span class="sd"> integer inputs to int64.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> inputs : Tuple[Union[Tensor, int, float], ...]</span>
|
||
<span class="sd"> The inputs to create tensors from.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A tuple of tensors.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">has_int64</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="ow">and</span> <span class="p">(</span><span class="n">i</span> <span class="o">>=</span> <span class="mi">2</span><span class="o">**</span><span class="mi">31</span> <span class="ow">or</span> <span class="n">i</span> <span class="o"><</span> <span class="o">-</span><span class="mi">2</span><span class="o">**</span><span class="mi">31</span><span class="p">)</span>\
|
||
<span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="ow">and</span> <span class="n">i</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">:</span>
|
||
<span class="n">has_int64</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="k">break</span>
|
||
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">has_int64</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">)</span>
|
||
|
||
<span class="n">result</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="ow">and</span> <span class="n">i</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span>
|
||
<span class="n">result</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span> <span class="k">if</span> <span class="n">has_int64</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">result</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">i</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">result</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
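# --- Illustrative note (not part of tensorrt_llm) -----------------------------
# The check `i >= 2**31 or i < -2**31` above flags exactly the Python ints
# that do not fit in a signed 32-bit integer (the int64 promotion trigger).
# The helper name `_needs_int64` is hypothetical:
def _needs_int64(i: int) -> bool:
    return i >= 2**31 or i < -2**31

assert not _needs_int64(2**31 - 1) and not _needs_int64(-2**31)  # int32 range
assert _needs_int64(2**31) and _needs_int64(-2**31 - 1)          # out of range
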
<div class="viewcode-block" id="broadcast_helper">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.broadcast_helper">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">broadcast_helper</span><span class="p">(</span><span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
|
||
<span class="n">right</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Helper function to perform a broadcast.</span>
|
||
|
||
<span class="sd"> For each input, that function first creates a constant tensor if the input</span>
|
||
<span class="sd"> is an integer or a float. Then, if needed, it expands the smaller tensor to</span>
|
||
<span class="sd"> make sure its rank is the same as the larger one.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> left : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The first input. If that input is an integer or a float, the</span>
|
||
<span class="sd"> function creates a constant tensor.</span>
|
||
|
||
<span class="sd"> right : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The second input. If that input is an integer or a float, the</span>
|
||
<span class="sd"> function creates a constant tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A pair of tensors of same rank.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">left</span><span class="p">)</span>
|
||
<span class="n">right</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">right</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span>
|
||
<span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="kc">None</span><span class="p">)</span>
|
||
<span class="n">right</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">left</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
|
||
<span class="k">return</span> <span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o"><</span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">></span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
|
||
<span class="n">right</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">left</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
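# --- Illustrative sketch (not part of tensorrt_llm) --------------------------
# The rank alignment performed above, modeled in NumPy: prepend size-1
# dimensions to the lower-rank operand until both ranks match (the role
# `expand_dims_like` plays here). The helper name `_align_ranks` is
# hypothetical.
import numpy as np

def _align_ranks(a, b):
    while a.ndim < b.ndim:
        a = a[None, ...]  # prepend a broadcast dimension
    while b.ndim < a.ndim:
        b = b[None, ...]
    return a, b

_l, _r = _align_ranks(np.ones((3,)), np.ones((2, 3)))
assert _l.shape == (1, 3) and _r.shape == (2, 3)
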
<div class="viewcode-block" id="elementwise_binary">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.elementwise_binary">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">elementwise_binary</span><span class="p">(</span><span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="nb">float</span><span class="p">],</span> <span class="n">right</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
|
||
<span class="n">op</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an elementwise operation with two inputs.</span>
|
||
|
||
<span class="sd"> For each input, that function first creates a constant tensor if the input</span>
|
||
<span class="sd"> is an integer or a float. Then, if needed, it expands the smaller tensor to</span>
|
||
<span class="sd"> make sure its rank is the same as the larger one. Then, it performs the</span>
|
||
<span class="sd"> elementwise operation 'op'.</span>
|
||
|
||
<span class="sd"> The following closures are defined in functional.*:</span>
|
||
|
||
<span class="sd"> add for op=trt.ElementWiseOperation.SUM</span>
|
||
<span class="sd"> sub for op=trt.ElementWiseOperation.SUB</span>
|
||
<span class="sd"> mul for op=trt.ElementWiseOperation.PROD</span>
|
||
<span class="sd"> div for op=trt.ElementWiseOperation.DIV</span>
|
||
<span class="sd"> floordiv for op=trt.ElementWiseOperation.FLOOR_DIV</span>
|
||
<span class="sd"> gt for op=trt.ElementWiseOperation.GREATER</span>
|
||
<span class="sd"> lt for op=trt.ElementWiseOperation.LESS</span>
|
||
<span class="sd"> op_and for op=trt.ElementWiseOperation.AND</span>
|
||
<span class="sd"> op_or for op=trt.ElementWiseOperation.OR</span>
|
||
<span class="sd"> eq for op=trt.ElementWiseOperation.EQUAL</span>
|
||
<span class="sd"> minimum for op=trt.ElementWiseOperation.MIN</span>
|
||
<span class="sd"> maximum for op=trt.ElementWiseOperation.MAX</span>
|
||
<span class="sd"> pow for op=trt.ElementWiseOperation.POW</span>
|
||
|
||
<span class="sd"> It is implemented using the IElementWiseLayer from TensorRT.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> left : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The first input. If that input is an integer or a float, the</span>
|
||
<span class="sd"> function creates a constant tensor.</span>
|
||
|
||
<span class="sd"> right : Union[Tensor, int, float]</span>
|
||
<span class="sd"> The second input. If that input is an integer or a float, the</span>
|
||
<span class="sd"> function creates a constant tensor.</span>
|
||
|
||
<span class="sd"> op : trt.ElementWiseOperation</span>
|
||
<span class="sd"> The binary operation to perform.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by this elementwise operation.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">left</span><span class="p">,</span> <span class="n">right</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span> <span class="ow">and</span> <span class="n">right</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">:</span>
|
||
<span class="n">left</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span> <span class="ow">and</span> <span class="n">right</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span>
|
||
<span class="n">right</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_elementwise</span><span class="p">(</span><span class="n">left</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">right</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">op</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="n">add</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">SUM</span><span class="p">)</span>
|
||
<span class="n">sub</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">SUB</span><span class="p">)</span>
|
||
<span class="n">mul</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">PROD</span><span class="p">)</span>
|
||
<span class="n">div</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">DIV</span><span class="p">)</span>
|
||
<span class="n">floordiv</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">FLOOR_DIV</span><span class="p">)</span>
|
||
<span class="n">gt</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">GREATER</span><span class="p">)</span>
|
||
<span class="n">lt</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">LESS</span><span class="p">)</span>
|
||
<span class="n">op_and</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">AND</span><span class="p">)</span>
|
||
<span class="n">op_or</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">OR</span><span class="p">)</span>
|
||
<span class="n">eq</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">EQUAL</span><span class="p">)</span>
|
||
<span class="n">minimum</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">MIN</span><span class="p">)</span>
|
||
<span class="n">maximum</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">MAX</span><span class="p">)</span>
|
||
<span class="nb">pow</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">POW</span><span class="p">)</span>
|
||
<span class="n">op_xor</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">XOR</span><span class="p">)</span>
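# --- Usage sketch (not part of the library source; tensor names are
# --- hypothetical): the closures above are plain functools.partial bindings
# --- over elementwise_binary, so Python scalars are promoted to constant
# --- tensors and broadcast before the layer is added, e.g.:
#
#   y = add(x, 1.0)                 # x + 1.0, the float becomes a constant
#   z = maximum(mul(x, 2), y)       # elementwise max(2 * x, y)
#   m = op_and(gt(x, 0), lt(x, 8))  # boolean mask: 0 < x < 8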

<div class="viewcode-block" id="modulo">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.modulo">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">modulo</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    This function adds an element-wise modulo (x % y) operation for a given tensor.</span>
<span class="sd">    Since there is no TensorRT layer that can directly perform this,</span>
<span class="sd">    this function implements it using some of the basic operations.</span>

<span class="sd">    Returns:</span>
<span class="sd">        A tensor that represents the (x % y) modulo operation.</span>
<span class="sd">    '''</span>
    <span class="k">return</span> <span class="n">x</span> <span class="o">-</span> <span class="p">(</span><span class="n">x</span> <span class="o">//</span> <span class="n">y</span><span class="p">)</span> <span class="o">*</span> <span class="n">y</span></div>
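# --- Sanity check of the identity used by modulo() above, in plain NumPy
# --- rather than TensorRT (illustration only): with floor division,
# --- x - (x // y) * y equals x % y, including for negative operands.
#
#   import numpy as np
#   x = np.array([7, -7, 5], dtype=np.int64)
#   y = np.array([3, 3, -2], dtype=np.int64)
#   assert np.array_equal(x - (x // y) * y, x % y)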


<div class="viewcode-block" id="where">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.where">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">where</span><span class="p">(</span><span class="n">condition</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">bool</span><span class="p">],</span> <span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
          <span class="n">right</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add a where (aka select or if-then-else) operation.</span>

<span class="sd">    Assuming the three input parameters have the same shape, that function creates</span>
<span class="sd">    the operation to compute a tensor of the same shape such that:</span>

<span class="sd">        for ii in range(mul(condition.shape)):</span>
<span class="sd">            output[ii] = left[ii] if condition[ii] else right[ii]</span>

<span class="sd">    For each input, that function first creates a constant tensor if the</span>
<span class="sd">    condition is boolean or the left/right input is an integer or a float.</span>
<span class="sd">    Then, if needed, it expands the smaller tensor to make sure its</span>
<span class="sd">    rank is the same as the larger one. Then, it performs the selection.</span>

<span class="sd">    It is implemented using the ISelectLayer from TensorRT.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        condition : Union[Tensor, bool]</span>
<span class="sd">            The condition. If that input is a boolean, the function</span>
<span class="sd">            creates a constant tensor.</span>

<span class="sd">        left : Union[Tensor, int, float]</span>
<span class="sd">            The first input. If that input is an integer or a float, the</span>
<span class="sd">            function creates a constant tensor.</span>

<span class="sd">        right : Union[Tensor, int, float]</span>
<span class="sd">            The second input. If that input is an integer or a float, the</span>
<span class="sd">            function creates a constant tensor.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor produced by this where operation.</span>
<span class="sd">    '''</span>
    <span class="c1"># Convert to tensors.</span>
    <span class="n">condition</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">condition</span><span class="p">)</span>
    <span class="n">left</span><span class="p">,</span> <span class="n">right</span> <span class="o">=</span> <span class="n">constants_to_tensors_</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>

    <span class="c1"># Find the tensor with the largest rank of the three.</span>
    <span class="n">largest</span> <span class="o">=</span> <span class="n">condition</span>
    <span class="k">if</span> <span class="n">largest</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o"><</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
        <span class="n">largest</span> <span class="o">=</span> <span class="n">left</span>
    <span class="k">if</span> <span class="n">largest</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o"><</span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
        <span class="n">largest</span> <span class="o">=</span> <span class="n">right</span>

    <span class="c1"># Expand the tensors to match the largest one.</span>
    <span class="k">if</span> <span class="n">condition</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">largest</span><span class="p">:</span>
        <span class="n">condition</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">condition</span><span class="p">,</span> <span class="n">largest</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">left</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">largest</span><span class="p">:</span>
        <span class="n">left</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">largest</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">right</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">largest</span><span class="p">:</span>
        <span class="n">right</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">largest</span><span class="p">)</span>

    <span class="c1"># Insert the operation.</span>
    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_select</span><span class="p">(</span><span class="n">condition</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">left</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
                                        <span class="n">right</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
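# --- Usage sketch (hypothetical tensor names, illustration only): where()
# --- applies the same scalar promotion as elementwise_binary, so a ReLU can
# --- be expressed with a select:
#
#   mask = gt(scores, 0.0)         # boolean condition tensor
#   y = where(mask, scores, 0.0)   # 0.0 is promoted to a constant tensor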


<div class="viewcode-block" id="unary">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.unary">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">unary</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">op</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add an elementwise operation on a single input.</span>

<span class="sd">    The following closures are defined in functional.*:</span>

<span class="sd">        round for op=trt.UnaryOperation.ROUND</span>
<span class="sd">        sqrt  for op=trt.UnaryOperation.SQRT</span>
<span class="sd">        exp   for op=trt.UnaryOperation.EXP</span>
<span class="sd">        sin   for op=trt.UnaryOperation.SIN</span>
<span class="sd">        cos   for op=trt.UnaryOperation.COS</span>
<span class="sd">        abs   for op=trt.UnaryOperation.ABS</span>
<span class="sd">        log   for op=trt.UnaryOperation.LOG</span>

<span class="sd">    It is implemented using the IUnaryLayer from TensorRT.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The input tensor.</span>

<span class="sd">        op : trt.UnaryOperation</span>
<span class="sd">            The unary operation to perform.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor produced by this elementwise operation.</span>
<span class="sd">    '''</span>
    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_unary</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">op</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>


<span class="nb">round</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">ROUND</span><span class="p">)</span>
<span class="n">sqrt</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">SQRT</span><span class="p">)</span>
<span class="n">exp</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">EXP</span><span class="p">)</span>
<span class="n">sin</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">SIN</span><span class="p">)</span>
<span class="n">cos</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">COS</span><span class="p">)</span>
<span class="nb">abs</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">ABS</span><span class="p">)</span>
<span class="n">log</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">LOG</span><span class="p">)</span>
<span class="n">not_op</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">NOT</span><span class="p">)</span>
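# --- Usage sketch (illustration only, hypothetical tensor x): the unary
# --- closures compose with the binary ones defined earlier, e.g. an RMS
# --- value along the last dimension (mean() is defined later in this module):
#
#   rms = sqrt(mean(x * x, dim=-1, keepdim=True))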

<div class="viewcode-block" id="log_softmax">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.log_softmax">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">log_softmax</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    This function is equivalent to torch.nn.functional.log_softmax(), i.e.</span>
<span class="sd">    it performs log(softmax(input)) in a safer and faster way.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input: Tensor</span>
<span class="sd">            The data tensor on which log_softmax is to be computed.</span>
<span class="sd">        dim: int</span>
<span class="sd">            The dimension of the input tensor along which log_softmax will be computed.</span>
<span class="sd">    Returns:</span>
<span class="sd">        A tensor of the same shape as input with log_softmax computed on the specified dim.</span>
<span class="sd">    '''</span>
    <span class="n">x_max</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="nb">input</span> <span class="o">-</span> <span class="n">x_max</span>
    <span class="k">return</span> <span class="n">x</span> <span class="o">-</span> <span class="n">log</span><span class="p">(</span><span class="nb">sum</span><span class="p">(</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span></div>
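# --- Why log_softmax() subtracts the max first (plain-NumPy illustration,
# --- not library code): exp() overflows for large logits, while the shifted
# --- form is exact because log_softmax(x) == log_softmax(x - c) for any
# --- constant c.
#
#   import numpy as np
#   x = np.array([1000.0, 1001.0, 1002.0])
#   naive = np.log(np.exp(x) / np.exp(x).sum())   # nan: exp overflows
#   s = x - x.max()
#   stable = s - np.log(np.exp(s).sum())          # finite: ~[-2.41, -1.41, -0.41]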


<div class="viewcode-block" id="reduce">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.reduce">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">reduce</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
           <span class="n">op</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="p">,</span>
           <span class="n">dim</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
           <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add a reduction operation along a dimension.</span>

<span class="sd">    It is implemented using the IReduceLayer from TensorRT.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The input tensor.</span>

<span class="sd">        op : trt.ReduceOperation</span>
<span class="sd">            The reduction operation to perform.</span>
<span class="sd">            Options: SUM, PROD, MAX, MIN, AVG</span>

<span class="sd">        dim : int</span>
<span class="sd">            The dimension along which the reduction is performed.</span>

<span class="sd">        keepdim : bool</span>
<span class="sd">            Is the dimension kept in the reduced tensor? When True the</span>
<span class="sd">            dimension is kept; otherwise it is removed from the shape.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor produced by this reduction operation.</span>
<span class="sd">    '''</span>
    <span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
    <span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>

    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_reduce</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
                                        <span class="n">op</span><span class="p">,</span>
                                        <span class="n">axes</span><span class="p">,</span>
                                        <span class="n">keep_dims</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>


<span class="n">prod</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">reduce</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">PROD</span><span class="p">)</span>
<span class="nb">min</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">reduce</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">MIN</span><span class="p">)</span>
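# --- Sketch of the axis bookkeeping in reduce() (illustration only):
# --- dim_resolve_negative maps e.g. dim=-1 of a rank-3 tensor to 2, and
# --- dim_to_trt_axes packs the dimensions into the bitmask that IReduceLayer
# --- expects, e.g. for dim=(0, 2):
#
#   axes = (1 << 0) | (1 << 2)   # 0b101: reduce over dimensions 0 and 2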

<div class="viewcode-block" id="mean">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.mean">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">mean</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
         <span class="n">dim</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
         <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add an operation to compute the mean along a dimension.</span>

<span class="sd">    Computes the mean along the dimension 'dim' of the input tensor.</span>

<span class="sd">    It is implemented using the IReduceLayer from TensorRT.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The input tensor.</span>

<span class="sd">        dim : int</span>
<span class="sd">            The dimension along which the mean is computed.</span>

<span class="sd">        keepdim : bool</span>
<span class="sd">            Is the dimension kept in the reduced tensor? When True the</span>
<span class="sd">            dimension is kept; otherwise it is removed from the shape.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor produced by this reduction operation.</span>
<span class="sd">    '''</span>
    <span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">AVG</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span></div>


<div class="viewcode-block" id="max">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.max">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">max</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add an operation to compute the max along a dimension.</span>

<span class="sd">    Computes the max along the dimension 'dim' of the input tensor.</span>

<span class="sd">    It is implemented using the IReduceLayer from TensorRT.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The input tensor.</span>

<span class="sd">        dim : int</span>
<span class="sd">            The dimension along which the max is computed.</span>

<span class="sd">        keepdim : bool</span>
<span class="sd">            Is the dimension kept in the reduced tensor? When True the</span>
<span class="sd">            dimension is kept; otherwise it is removed from the shape.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor produced by this reduction operation.</span>
<span class="sd">    '''</span>
    <span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">MAX</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span></div>


<div class="viewcode-block" id="sum">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.sum">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">sum</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add an operation to compute the sum along a dimension.</span>

<span class="sd">    Computes the sum along the dimension 'dim' of the input tensor.</span>

<span class="sd">    It is implemented using the IReduceLayer from TensorRT.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The input tensor.</span>

<span class="sd">        dim : int</span>
<span class="sd">            The dimension along which the sum is computed.</span>

<span class="sd">        keepdim : bool</span>
<span class="sd">            Is the dimension kept in the reduced tensor? When True the</span>
<span class="sd">            dimension is kept; otherwise it is removed from the shape.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor produced by this reduction operation.</span>
<span class="sd">    '''</span>
    <span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">SUM</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span></div>
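# --- The three wrappers above differ only in the trt.ReduceOperation they
# --- pass to reduce() (usage sketch, hypothetical tensor x):
#
#   mu = mean(x, dim=-1, keepdim=True)                     # AVG
#   var = mean((x - mu) * (x - mu), dim=-1, keepdim=True)  # variance
#   total = sum(x, dim=0)                                  # SUM, dim 0 removed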


<div class="viewcode-block" id="identity">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.identity">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">identity</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add an identity operation.</span>

<span class="sd">    TODO: Document why it can be done using a plugin!!!</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The input tensor.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor produced by this identity operation.</span>
<span class="sd">    '''</span>
    <span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">identity_plugin</span><span class="p">:</span>
        <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_identity</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
            <span class="s1">'Identity'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
        <span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
        <span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">()</span>
        <span class="n">id_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"identity"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
        <span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
        <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">id_plug</span><span class="p">)</span>
        <span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creator</span><span class="p">,</span> <span class="s2">"identity"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>


<div class="viewcode-block" id="argmax">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.argmax">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">argmax</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add an argmax operation.</span>

<span class="sd">    As explained in the ONNX documentation,</span>

<span class="sd">        https://github.com/onnx/onnx/blob/main/docs/Operators.md#argmax</span>

<span class="sd">    that function creates a layer computing the indices of the max elements of</span>
<span class="sd">    the input tensor along the provided dim. The resulting tensor</span>
<span class="sd">    has the same rank as the input if keepdim is True. If keepdim is False,</span>
<span class="sd">    then the resulting tensor has the reduced dimension pruned.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The input tensor.</span>

<span class="sd">        dim : int</span>
<span class="sd">            The dimension in which to compute the argmax indices.</span>

<span class="sd">        keepdim : bool</span>
<span class="sd">            Do we keep the dimension along which the reduction is performed?</span>
<span class="sd">            Yes if set to True; no otherwise.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor produced by this argmax operation.</span>
<span class="sd">    '''</span>
    <span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
    <span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>

    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_topk</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">TopKOperation</span><span class="o">.</span><span class="n">MAX</span><span class="p">,</span>
                                      <span class="mi">1</span><span class="p">,</span> <span class="n">axes</span><span class="p">)</span>
    <span class="n">output</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">keepdim</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">layer</span><span class="p">)</span>

    <span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">layer</span><span class="p">)</span>
    <span class="n">a</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()))</span>
    <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="n">dim</span><span class="p">:</span>
        <span class="n">a</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="n">d</span><span class="p">)</span>
    <span class="n">indices</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="n">a</span><span class="p">))</span>
    <span class="n">output_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">output</span><span class="p">)</span>
    <span class="n">new_shape</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span><span class="n">output_shape</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">indices</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">view</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">new_shape</span><span class="p">)</span></div>
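# --- Shape handling in the keepdim=False branch above (illustration only):
# --- ITopKLayer keeps the reduced dimension with size 1, so argmax gathers
# --- the surviving entries of the output shape and reshapes. For an input of
# --- shape (B, S, V) and dim=2:
#
#   a = [0, 1, 2]
#   a.pop(2)   # a == [0, 1]: keep dims 0 and 1, view (B, S, 1) as (B, S)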


<div class="viewcode-block" id="gelu">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gelu">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">gelu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add a GELU operation.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        x : Tensor</span>
<span class="sd">            The input tensor on which the activation function is applied.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor produced by the activation layer.</span>
<span class="sd">    '''</span>
    <span class="k">return</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">x</span> <span class="o">*</span> <span class="p">(</span>
        <span class="n">tanh</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">2.0</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="mf">0.044715</span> <span class="o">*</span> <span class="nb">pow</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mf">3.0</span><span class="p">)))</span> <span class="o">+</span> <span class="mf">1.0</span><span class="p">)</span></div>
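# --- gelu() above uses the tanh approximation of GELU (plain-NumPy
# --- illustration, not library code); it closely tracks the exact
# --- x * Phi(x) form:
#
#   import math
#   import numpy as np
#   x = np.linspace(-3.0, 3.0, 7)
#   approx = 0.5 * x * (np.tanh(math.sqrt(2.0 / math.pi)
#                               * (x + 0.044715 * x**3)) + 1.0)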


<div class="viewcode-block" id="geglu">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.geglu">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">geglu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add a Gated-GELU operation.</span>

<span class="sd">    That function takes a tensor, splits it into two halves along the last</span>
<span class="sd">    dimension, applies GELU to the second half and multiplies the results. The</span>
<span class="sd">    behavior is undefined if the last dimension is not even.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        x : Tensor</span>
<span class="sd">            The input tensor on which the activation function is applied.</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor produced by the activation layer.</span>
<span class="sd">    '''</span>
    <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="n">chunk</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">a</span> <span class="o">*</span> <span class="n">gelu</span><span class="p">(</span><span class="n">b</span><span class="p">)</span></div>


<div class="viewcode-block" id="quick_gelu">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.quick_gelu">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">quick_gelu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
    <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">sigmoid</span><span class="p">(</span><span class="mf">1.702</span> <span class="o">*</span> <span class="n">x</span><span class="p">)</span></div>


<div class="viewcode-block" id="gegelu">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gegelu">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">gegelu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">limit</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
    <span class="c1"># a, b = x[..., ::2], x[..., 1::2]</span>
    <span class="n">ndim</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
    <span class="n">a_starts</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
    <span class="n">b_starts</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
    <span class="n">shapes</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
        <span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)</span>
    <span class="p">])</span>
    <span class="n">strides</span> <span class="o">=</span> <span class="p">[</span><span class="mi">2</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>

    <span class="n">a</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">a_starts</span><span class="p">,</span> <span class="n">shapes</span><span class="p">,</span> <span class="n">strides</span><span class="p">)</span>
    <span class="n">b</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">b_starts</span><span class="p">,</span> <span class="n">shapes</span><span class="p">,</span> <span class="n">strides</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">limit</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">a</span> <span class="o">=</span> <span class="n">clip</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="nb">float</span><span class="p">(</span><span class="o">-</span><span class="mf">1e20</span><span class="p">),</span> <span class="n">beta</span><span class="o">=</span><span class="n">limit</span><span class="p">)</span>
        <span class="n">b</span> <span class="o">=</span> <span class="n">clip</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=-</span><span class="n">limit</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="n">limit</span><span class="p">)</span>

    <span class="c1"># C = B + 1</span>
    <span class="n">const1</span> <span class="o">=</span> <span class="n">arange</span><span class="p">(</span><span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="mi">1</span><span class="p">)),</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="mi">2</span><span class="p">)),</span>
                    <span class="n">trt_dtype_to_str</span><span class="p">(</span><span class="n">b</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span>
    <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
        <span class="n">const1</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">const1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>

    <span class="n">b_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)])</span>
    <span class="n">const1_arr</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">const1</span><span class="p">,</span> <span class="n">b_shape</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">quick_gelu</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">b</span> <span class="o">+</span> <span class="n">const1_arr</span><span class="p">)</span></div>
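# --- What the strided slices in gegelu() compute (illustration only): unlike
# --- geglu, which splits the last dimension into contiguous halves, gegelu
# --- takes the interleaved columns a = x[..., ::2] and b = x[..., 1::2]
# --- (stride-2 slices), optionally clips both when limit is set, and returns
# --- quick_gelu(a) * (b + 1); the arange/expand_dims/expand sequence just
# --- builds a broadcastable constant 1 with b's dtype.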
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="group_norm">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.group_norm">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">group_norm</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">num_groups</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">):</span>
|
||
|
||
<span class="c1">##</span>
|
||
<span class="c1">## TODO: Document that function!</span>
|
||
<span class="c1">##</span>
|
||
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">num_channels</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="n">old_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
|
||
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
|
||
<span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">num_groups</span><span class="p">,</span>
|
||
<span class="n">num_channels</span> <span class="o">//</span> <span class="n">num_groups</span><span class="p">,</span>
|
||
<span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">ndim</span><span class="p">)])</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
|
||
|
||
<span class="c1"># instance norm</span>
|
||
<span class="n">w_shape</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_groups</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)]</span>
|
||
<span class="n">instance_weight</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)))</span>
|
||
<span class="n">instance_bias</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)))</span>
|
||
<span class="n">axes_mask</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()):</span>
|
||
<span class="n">axes_mask</span> <span class="o">|=</span> <span class="mi">1</span> <span class="o"><<</span> <span class="n">i</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_normalization</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">instance_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">instance_bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">axes_mask</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">epsilon</span> <span class="o">=</span> <span class="n">eps</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">old_shape</span><span class="p">)</span>
|
||
|
||
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">num_channels</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">ndim</span><span class="p">)])</span>
|
||
<span class="k">if</span> <span class="n">weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">*</span> <span class="n">weight</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">+</span> <span class="n">bias</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">y</span></div>


<span class="k">def</span><span class="w"> </span><span class="nf">softplus</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add the softplus activation base on PyTorch definition.</span>
|
||
|
||
<span class="sd"> See https://pytorch.org/docs/stable/generated/torch.nn.functional.softplus.html#torch-nn-functional-softplus for a</span>
|
||
<span class="sd"> description of that function.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> Input TensorRT-LLM Tensor.</span>
|
||
<span class="sd"> beta : float</span>
|
||
<span class="sd"> The parameter for softplus computation.</span>
|
||
<span class="sd"> threshold : float</span>
|
||
<span class="sd"> The threshold for reverting to the linear function when input * beta > threshold</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The output tensor created by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">sf_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_activation</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">SOFTPLUS</span><span class="p">)</span>
|
||
<span class="n">sf_layer</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">beta</span>
|
||
<span class="n">sf_layer</span><span class="o">.</span><span class="n">beta</span> <span class="o">=</span> <span class="n">beta</span>
|
||
|
||
<span class="n">prod_tensor</span> <span class="o">=</span> <span class="nb">input</span> <span class="o">*</span> <span class="n">beta</span>
|
||
<span class="n">result</span> <span class="o">=</span> <span class="n">prod_tensor</span> <span class="o">></span> <span class="n">threshold</span>
|
||
|
||
<span class="k">return</span> <span class="n">where</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">sf_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">sf_layer</span><span class="p">))</span></div>


<span class="k">def</span><span class="w"> </span><span class="nf">outer</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">vec2</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to compute the outer product between two tensors.</span>
|
||
|
||
<span class="sd"> That operation creates an Einsum node.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The first input tensor.</span>
|
||
|
||
<span class="sd"> vec2 : Tensor</span>
|
||
<span class="sd"> The second input tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The output tensor produced by this layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">return</span> <span class="n">einsum</span><span class="p">(</span><span class="s1">'i,j->ij'</span><span class="p">,</span> <span class="p">[</span><span class="nb">input</span><span class="p">,</span> <span class="n">vec2</span><span class="p">])</span></div>


<span class="k">def</span><span class="w"> </span><span class="nf">avg_pool2d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
|
||
<span class="n">stride</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">padding</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">ceil_mode</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">count_include_pad</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
|
||
<span class="c1">##</span>
|
||
<span class="c1">## TODO: Document that function!</span>
|
||
<span class="c1">##</span>
|
||
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="nb">input</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_pooling_nd</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PoolingType</span><span class="o">.</span><span class="n">AVERAGE</span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">stride</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">stride</span> <span class="o">=</span> <span class="n">kernel_size</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="n">stride</span>
|
||
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)]))</span>
|
||
|
||
<span class="k">return</span> <span class="n">output</span></div>


<span class="k">def</span><span class="w"> </span><span class="nf">conv1d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">padding</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">dilation</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
|
||
<span class="n">noutput</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="n">kernel_size</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">2</span><span class="p">]</span>
|
||
<span class="n">is_weight_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_weight_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">is_bias_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">bias</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_bias_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="n">input_shuffled</span> <span class="o">=</span> <span class="n">stack</span><span class="p">([</span><span class="nb">input</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
|
||
<span class="n">kernel_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Dims</span><span class="p">([</span><span class="n">kernel_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_convolution_nd</span><span class="p">(</span><span class="n">input_shuffled</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">noutput</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="p">(</span><span class="n">stride</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">padding_nd</span> <span class="o">=</span> <span class="p">(</span><span class="n">padding</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">dilation_nd</span> <span class="o">=</span> <span class="p">(</span><span class="n">dilation</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">num_groups</span> <span class="o">=</span> <span class="n">groups</span>
|
||
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">is_weight_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">is_bias_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="n">output_2d</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="n">output_1d</span> <span class="o">=</span> <span class="n">squeeze</span><span class="p">(</span><span class="n">output_2d</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">output_1d</span></div>


<span class="k">def</span><span class="w"> </span><span class="nf">conv2d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">stride</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">padding</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">dilation</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">pre_padding</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">post_padding</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="c1">##</span>
|
||
<span class="c1">## TODO: Document that function!</span>
|
||
<span class="c1">##</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="nb">input</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="n">noutput</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="n">kernel_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
||
|
||
<span class="n">is_weight_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_weight_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">is_bias_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">bias</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_bias_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_convolution_nd</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">noutput</span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="n">stride</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">padding_nd</span> <span class="o">=</span> <span class="n">padding</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">dilation_nd</span> <span class="o">=</span> <span class="n">dilation</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">num_groups</span> <span class="o">=</span> <span class="n">groups</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">dilation_nd</span> <span class="o">=</span> <span class="n">dilation</span>
|
||
<span class="k">if</span> <span class="n">pre_padding</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">pre_padding</span> <span class="o">=</span> <span class="n">pre_padding</span>
|
||
<span class="k">if</span> <span class="n">post_padding</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">post_padding</span> <span class="o">=</span> <span class="n">post_padding</span>
|
||
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">is_weight_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">is_bias_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)]))</span>
|
||
|
||
<span class="k">return</span> <span class="n">output</span></div>


<span class="k">def</span><span class="w"> </span><span class="nf">conv3d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">stride</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">padding</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">dilation</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="c1">##</span>
|
||
<span class="c1">## TODO: Document this function!</span>
|
||
<span class="c1">##</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="c1"># TRT requires the input of Conv3D layer to be 5-dimentional tensor.</span>
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">4</span><span class="p">:</span>
|
||
<span class="nb">input</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">5</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">stride</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">stride</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="n">stride</span><span class="p">]</span> <span class="o">*</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">padding</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">padding</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="n">padding</span><span class="p">]</span> <span class="o">*</span> <span class="mi">3</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dilation</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="n">dilation</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="n">dilation</span><span class="p">]</span> <span class="o">*</span> <span class="mi">3</span><span class="p">)</span>
|
||
|
||
<span class="n">noutput</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="n">kernel_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">3</span><span class="p">],</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
||
|
||
<span class="n">is_weight_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_weight_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">is_bias_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">bias</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_bias_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_convolution_nd</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">noutput</span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="n">stride</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">padding_nd</span> <span class="o">=</span> <span class="n">padding</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">dilation_nd</span> <span class="o">=</span> <span class="n">dilation</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">num_groups</span> <span class="o">=</span> <span class="n">groups</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">dilation_nd</span> <span class="o">=</span> <span class="n">dilation</span>
|
||
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">is_weight_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">is_bias_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">output</span></div>


<span class="k">def</span><span class="w"> </span><span class="nf">conv_transpose2d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">stride</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">padding</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">output_padding</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">dilation</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="c1">##</span>
|
||
<span class="c1">## TODO: Document that function!</span>
|
||
<span class="c1">##</span>
|
||
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="nb">input</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="n">noutput</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span>
|
||
<span class="n">kernel_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
|
||
|
||
<span class="n">is_weight_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_weight_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">is_bias_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">bias</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="ow">and</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
|
||
<span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_bias_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_deconvolution_nd</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">noutput</span><span class="p">,</span>
|
||
<span class="n">kernel_size</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="n">stride</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">padding_nd</span> <span class="o">=</span> <span class="n">padding</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">num_groups</span> <span class="o">=</span> <span class="n">groups</span>
|
||
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">is_weight_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">is_bias_constant</span><span class="p">:</span>
|
||
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
|
||
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
|
||
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)]))</span>
|
||
|
||
<span class="k">return</span> <span class="n">output</span></div>


<span class="k">def</span><span class="w"> </span><span class="nf">split</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">split_size_or_sections</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that splits a tensor into sub-tensors.</span>
|
||
|
||
<span class="sd"> This operation creates a list of tensors that are obtained from the input</span>
|
||
<span class="sd"> tensor by slicing it along the dimension 'dim'. If 'split_size_or_sections'</span>
|
||
<span class="sd"> is an integer, the tensor is split into 'input.shape[dim] /</span>
|
||
<span class="sd"> split_size_or_sections' slices. If 'split_size_or_sections' is a list of</span>
|
||
<span class="sd"> sizes, the tensor is split into 'len(split_size_or_sections)' slices and</span>
|
||
<span class="sd"> the size of the ith slice is given by 'split_size_or_sections[i]'.</span>
|
||
|
||
<span class="sd"> There are several constraints with the current implementation:</span>
|
||
|
||
<span class="sd"> - The input tensor must be static (no dynamic dimension),</span>
|
||
<span class="sd"> - If 'split_size_or_sections' is an integer, the number of elements in</span>
|
||
<span class="sd"> the 'dim' dimension of the input must be a multiple of</span>
|
||
<span class="sd"> 'split_size_or_sections': 'input.shape[dim] % split_size_or_sections == 0'.</span>
|
||
<span class="sd"> - If 'split_size_or_sections' is a sequence, the sum of the elements in</span>
|
||
<span class="sd"> 'split_size_or_sections' must be equal to the size in the dimension</span>
|
||
<span class="sd"> 'dim': 'input.shape[dim] == sum(ii for ii in split_size_or_sections)'.</span>
|
||
|
||
<span class="sd"> That operation is implemented using a 'slice' operation for each output</span>
|
||
<span class="sd"> slice.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor to slice.</span>
|
||
|
||
<span class="sd"> split_size_or_sections : Union[int, Sequence[int]]</span>
|
||
<span class="sd"> If it is an integer, it encodes the size of each slice. Otherwise,</span>
|
||
<span class="sd"> if it is a sequence, it is the size of each slice.</span>
|
||
|
||
<span class="sd"> dim : int</span>
|
||
<span class="sd"> The dimension of the tensor to slice.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The list of tensors produced by the different operations.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="n">tensor</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
|
||
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">dim</span> <span class="o">+=</span> <span class="n">ndim</span>
|
||
<span class="n">dim_value</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="n">dim</span><span class="p">]</span>
|
||
<span class="n">starts</span> <span class="o">=</span> <span class="p">[</span><span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="mi">0</span><span class="p">]))</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
<span class="n">sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">split_size_or_sections</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="c1"># TODO: support non-divisible cases</span>
|
||
<span class="k">assert</span> <span class="n">dim_value</span> <span class="o">%</span> <span class="n">split_size_or_sections</span> <span class="o">==</span> <span class="mi">0</span>
|
||
<span class="n">num_sections</span> <span class="o">=</span> <span class="n">dim_value</span> <span class="o">//</span> <span class="n">split_size_or_sections</span>
|
||
<span class="n">sizes</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="n">split_size_or_sections</span><span class="p">]))</span>
|
||
|
||
<span class="n">outputs</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_sections</span><span class="p">):</span>
|
||
<span class="n">starts</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="n">split_size_or_sections</span> <span class="o">*</span> <span class="n">i</span><span class="p">]))</span>
|
||
<span class="n">outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">concat</span><span class="p">(</span><span class="n">starts</span><span class="p">),</span> <span class="n">concat</span><span class="p">(</span><span class="n">sizes</span><span class="p">)))</span>
|
||
<span class="k">return</span> <span class="n">outputs</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">total_size</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">split_size_or_sections</span><span class="p">:</span>
|
||
<span class="n">total_size</span> <span class="o">+=</span> <span class="n">i</span>
|
||
<span class="k">assert</span> <span class="n">dim_value</span> <span class="o">==</span> <span class="n">total_size</span>
|
||
<span class="n">num_sections</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">split_size_or_sections</span><span class="p">)</span>
|
||
|
||
<span class="n">outputs</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_sections</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">i</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">starts</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">starts</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">+</span> <span class="n">sizes</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span>
|
||
<span class="n">sizes</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="n">split_size_or_sections</span><span class="p">[</span><span class="n">i</span><span class="p">]]))</span>
|
||
<span class="n">outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">concat</span><span class="p">(</span><span class="n">starts</span><span class="p">),</span> <span class="n">concat</span><span class="p">(</span><span class="n">sizes</span><span class="p">)))</span>
|
||
<span class="k">return</span> <span class="n">outputs</span></div>
<div class="viewcode-block" id="chunk">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.chunk">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">chunk</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">chunks</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">'''</span>
<span class="sd"> Add an operation that splits a tensor into a given number of sub-tensors.</span>

<span class="sd"> This operation creates a list of tensors that are obtained from the input</span>
<span class="sd"> tensor by chunking it along the dimension 'dim'. It produces 'chunks'</span>
<span class="sd"> sub-tensors.</span>

<span class="sd"> That operation is only defined for static tensors (no dynamic dimension)</span>
<span class="sd"> and the size of the tensor in the dimension 'dim' must be a multiple of</span>
<span class="sd"> 'chunks': 'input.shape[dim] % chunks == 0'.</span>

<span class="sd"> It maps to 'split' with 'split_size = input.shape[dim] / chunks'.</span>

<span class="sd"> Parameters:</span>
<span class="sd"> tensor : Tensor</span>
<span class="sd"> The input tensor to slice.</span>

<span class="sd"> chunks : int</span>
<span class="sd"> The number of slices to split the input tensor into.</span>

<span class="sd"> dim : int</span>
<span class="sd"> The dimension of the tensor to slice.</span>

<span class="sd"> Returns:</span>
<span class="sd"> The list of tensors produced by the different operations.</span>
<span class="sd"> '''</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="n">tensor</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>

<span class="n">ndim</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">if</span> <span class="n">dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">dim</span> <span class="o">+=</span> <span class="n">ndim</span>
<span class="n">dim_value</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="n">dim</span><span class="p">]</span>
<span class="k">assert</span> <span class="n">dim_value</span> <span class="o">%</span> <span class="n">chunks</span> <span class="o">==</span> <span class="mi">0</span>

<span class="k">return</span> <span class="n">split</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">dim_value</span> <span class="o">//</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span></div>
<div class="viewcode-block" id="unbind">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.unbind">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">unbind</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
<span class="w"> </span><span class="sd">'''</span>
<span class="sd"> Removes a tensor dimension.</span>

<span class="sd"> Returns a list of all slices along the given dimension, each with that dimension removed.</span>
<span class="sd"> '''</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
<span class="n">output_shape</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="n">dim</span><span class="p">]</span>
<span class="k">return</span> <span class="p">[</span><span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">output_shape</span><span class="p">)</span> <span class="k">for</span> <span class="n">output</span> <span class="ow">in</span> <span class="n">outputs</span><span class="p">]</span></div>
<div class="viewcode-block" id="AllReduceStrategy">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceStrategy">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">AllReduceStrategy</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="n">NCCL</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">MIN_LATENCY</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">UB</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">AUTO</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">ONESHOT</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">TWOSHOT</span> <span class="o">=</span> <span class="mi">5</span>
<span class="n">LOWPRECISION</span> <span class="o">=</span> <span class="mi">6</span>
<span class="n">MNNVL</span> <span class="o">=</span> <span class="mi">7</span>
<span class="n">NCCL_SYMMETRIC</span> <span class="o">=</span> <span class="mi">8</span></div>
<div class="viewcode-block" id="AllReduceFusionOp">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceFusionOp">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">AllReduceFusionOp</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="n">NONE</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">RESIDUAL_RMS_NORM</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">LAST_PROCESS_FOR_UB</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">RESIDUAL_RMS_PREPOST_NORM</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">RESIDUAL_RMS_NORM_QUANT_FP8</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">RESIDUAL_RMS_NORM_QUANT_NVFP4</span> <span class="o">=</span> <span class="mi">5</span>
<span class="n">RESIDUAL_RMS_NORM_OUT_QUANT_FP8</span> <span class="o">=</span> <span class="mi">6</span>
<span class="n">RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4</span> <span class="o">=</span> <span class="mi">7</span>
<span class="n">MOE_FINALIZE_ALLREDUCE_RESIDUAL_RMS_NORM</span> <span class="o">=</span> <span class="mi">8</span></div>
<div class="viewcode-block" id="AllReduceParams">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceParams">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">AllReduceParams</span><span class="p">():</span>

<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">strategy</span><span class="p">:</span> <span class="n">AllReduceStrategy</span> <span class="o">=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">AUTO</span><span class="p">,</span>
<span class="n">fusion_op</span><span class="p">:</span> <span class="n">AllReduceFusionOp</span> <span class="o">=</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">NONE</span><span class="p">,</span>
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">residual</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">norm_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">norm_pre_residual_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-06</span><span class="p">,</span>
<span class="n">enable_allreduce</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">trigger_completion_at_end</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">strategy</span> <span class="o">=</span> <span class="n">strategy</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">=</span> <span class="n">fusion_op</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span>
<span class="bp">self</span><span class="o">.</span><span class="n">residual</span> <span class="o">=</span> <span class="n">residual</span>
<span class="bp">self</span><span class="o">.</span><span class="n">norm_weight</span> <span class="o">=</span> <span class="n">norm_weight</span>
<span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">scale</span>
<span class="bp">self</span><span class="o">.</span><span class="n">norm_pre_residual_weight</span> <span class="o">=</span> <span class="n">norm_pre_residual_weight</span>
<span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="c1"># For torch path only, has no effect on TRT path</span>
<span class="bp">self</span><span class="o">.</span><span class="n">enable_allreduce</span> <span class="o">=</span> <span class="n">enable_allreduce</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trigger_completion_at_end</span> <span class="o">=</span> <span class="n">trigger_completion_at_end</span>
<span class="k">assert</span> <span class="n">fusion_op</span> <span class="o">==</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">NONE</span><span class="o">.</span><span class="n">value</span> <span class="ow">or</span> <span class="p">(</span><span class="n">residual</span>
<span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span>

<div class="viewcode-block" id="AllReduceParams.has_affine">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceParams.has_affine">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">has_affine</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="mi">1</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="mi">0</span></div>

<div class="viewcode-block" id="AllReduceParams.has_bias">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceParams.has_bias">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">has_bias</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="mi">1</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="mi">0</span></div>

<div class="viewcode-block" id="AllReduceParams.has_scale">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceParams.has_scale">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">has_scale</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="mi">1</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="mi">0</span></div>

<div class="viewcode-block" id="AllReduceParams.update_strategy">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceParams.update_strategy">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">update_strategy</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">strategy</span> <span class="o">==</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">AUTO</span> <span class="ow">and</span> <span class="n">default_net</span><span class="p">(</span>
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">user_buffer</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">strategy</span> <span class="o">=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span></div>
</div>
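# Editor's illustration (hypothetical helper, not library source): building
# params for a fused residual-add + RMSNorm all-reduce. Any fusion other
# than NONE requires 'residual', per the assert in __init__ above.
def _example_allreduce_params(residual: Tensor, norm_weight: Tensor):
    return AllReduceParams(strategy=AllReduceStrategy.AUTO,
                           fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
                           residual=residual,
                           norm_weight=norm_weight,
                           eps=1e-6)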
<div class="viewcode-block" id="MoEAllReduceParams">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.MoEAllReduceParams">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">MoEAllReduceParams</span><span class="p">(</span><span class="n">AllReduceParams</span><span class="p">):</span>

<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">device_num_experts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">expert_scale_factor</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">expanded_idx_to_permuted_idx</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">shared_expert_output</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">residual</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">norm_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">norm_pre_residual_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-06</span><span class="p">,</span>
<span class="n">enable_allreduce</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">is_cutlass_min_latency</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
<span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span>
<span class="n">residual</span><span class="o">=</span><span class="n">residual</span><span class="p">,</span>
<span class="n">norm_weight</span><span class="o">=</span><span class="n">norm_weight</span><span class="p">,</span>
<span class="n">scale</span><span class="o">=</span><span class="n">scale</span><span class="p">,</span>
<span class="n">norm_pre_residual_weight</span><span class="o">=</span><span class="n">norm_pre_residual_weight</span><span class="p">,</span>
<span class="n">eps</span><span class="o">=</span><span class="n">eps</span><span class="p">,</span>
<span class="n">enable_allreduce</span><span class="o">=</span><span class="n">enable_allreduce</span><span class="p">,</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">device_num_experts</span> <span class="o">=</span> <span class="n">device_num_experts</span>
<span class="bp">self</span><span class="o">.</span><span class="n">expert_scale_factor</span> <span class="o">=</span> <span class="n">expert_scale_factor</span>
<span class="bp">self</span><span class="o">.</span><span class="n">expanded_idx_to_permuted_idx</span> <span class="o">=</span> <span class="n">expanded_idx_to_permuted_idx</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shared_expert_output</span> <span class="o">=</span> <span class="n">shared_expert_output</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_cutlass_min_latency</span> <span class="o">=</span> <span class="n">is_cutlass_min_latency</span>

<div class="viewcode-block" id="MoEAllReduceParams.is_valid">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.MoEAllReduceParams.is_valid">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">is_valid</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_cutlass_min_latency</span><span class="p">:</span>
<span class="k">return</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device_num_experts</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">expert_scale_factor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">shared_expert_output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">expanded_idx_to_permuted_idx</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span></div>
</div>
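# Editor's illustration (hypothetical helper): 'is_valid' checks different
# fields per kernel path; the default (non-cutlass-min-latency) path only
# needs the expanded-to-permuted index mapping.
def _example_moe_params(mapping: Tensor, residual: Tensor,
                        norm_weight: Tensor) -> bool:
    params = MoEAllReduceParams(expanded_idx_to_permuted_idx=mapping,
                                residual=residual,
                                norm_weight=norm_weight)
    return params.is_valid()  # True on the default path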
<div class="viewcode-block" id="create_allreduce_plugin">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.create_allreduce_plugin">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">create_allreduce_plugin</span><span class="p">(</span>
<span class="n">network</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">INetworkDefinition</span><span class="p">,</span>
<span class="n">tensor</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ITensor</span><span class="p">,</span>
<span class="n">workspace</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">ITensor</span><span class="p">],</span>
<span class="n">group</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">,</span>
<span class="n">dtype</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">,</span>
<span class="n">all_reduce_params</span><span class="p">:</span> <span class="n">AllReduceParams</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">allreduce_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">'AllReduce'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">allreduce_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>

<span class="n">pf_group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"group"</span><span class="p">,</span> <span class="n">group</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pf_dtype</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">dtype</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="p">[</span><span class="n">pf_group</span><span class="p">,</span> <span class="n">pf_dtype</span><span class="p">]</span>
<span class="n">p_strategy</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"strategy"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_strategy</span><span class="p">)</span>
<span class="n">p_fusion_op</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"fusion_op"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_fusion_op</span><span class="p">)</span>
<span class="n">p_eps</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"eps"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">float</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">eps</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_eps</span><span class="p">)</span>
<span class="n">p_affine</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"affine"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_affine</span><span class="p">())],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_affine</span><span class="p">)</span>
<span class="n">p_bias</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"bias"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_bias</span><span class="p">())],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_bias</span><span class="p">)</span>
<span class="n">p_scale</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"scale"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_scale</span><span class="p">())],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_scale</span><span class="p">)</span>

<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">(</span><span class="n">pfc</span><span class="p">)</span>
<span class="n">ar_plug</span> <span class="o">=</span> <span class="n">allreduce_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"allreduce"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="p">]</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">!=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">NCCL</span> <span class="ow">and</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">!=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span><span class="p">:</span>
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">workspace</span><span class="p">)</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">!=</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">NONE</span><span class="p">:</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_bias</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">residual</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_affine</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">norm_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">==</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">RESIDUAL_RMS_PREPOST_NORM</span><span class="p">:</span>
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">all_reduce_params</span><span class="o">.</span><span class="n">norm_pre_residual_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_scale</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>

<span class="n">layer</span> <span class="o">=</span> <span class="n">network</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">ar_plug</span><span class="p">)</span>
<span class="k">return</span> <span class="n">layer</span><span class="p">,</span> <span class="n">allreduce_plg_creator</span><span class="p">,</span> <span class="n">pfc</span></div>
<span class="n">allreduce_ub_counter</span> <span class="o">=</span> <span class="mi">0</span>
<div class="viewcode-block" id="allreduce">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.allreduce">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">allreduce</span><span class="p">(</span>
|
||
<span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
|
||
<span class="n">all_reduce_params</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AllReduceParams</span><span class="p">]</span> <span class="o">=</span> <span class="n">AllReduceParams</span><span class="p">()</span>
|
||
<span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs a collective all-reduce.</span>
|
||
|
||
<span class="sd"> Let's define 'world_size' as the length of the 'group' list. That functions</span>
|
||
<span class="sd"> creates a layer to compute the sum of 'world_size' tensors distributed</span>
|
||
<span class="sd"> amongst the 'world_size' participating ranks (one GPU per rank).</span>
|
||
|
||
<span class="sd"> The list 'group' contains the identifiers of the ranks participating into</span>
|
||
<span class="sd"> the collective operation.</span>
|
||
|
||
<span class="sd"> The tensors in the different ranks must be 1D tensors (or views) and the output</span>
|
||
<span class="sd"> tensor will have that same shape. The output tensor will be replicated on</span>
|
||
<span class="sd"> the 'world_size' ranks.</span>
|
||
|
||
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL all-reduce</span>
|
||
<span class="sd"> collective operation. See</span>
|
||
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce</span>
|
||
<span class="sd"> for details.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> group : List[int]</span>
|
||
<span class="sd"> The ranks participating into the all-reduce operation.</span>
|
||
|
||
<span class="sd"> strategy: AllReduceStrategy</span>
|
||
<span class="sd"> NCCL delegates all-reduce to NCCL while ONESHOT and TWOSHOT are custom latency-optimal algorithms.</span>
|
||
<span class="sd"> AUTO chooses amongst the three based on a message-size heuristic.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="k">global</span> <span class="n">allreduce_ub_counter</span>
|
||
<span class="n">allreduce_ub_counter</span> <span class="o">+=</span> <span class="mi">1</span>
|
||
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">all_reduce_params</span> <span class="o">=</span> <span class="n">AllReduceParams</span><span class="p">()</span>
|
||
<span class="n">all_reduce_params</span><span class="o">.</span><span class="n">update_strategy</span><span class="p">()</span>
|
||
|
||
<span class="c1"># TODO(TRTLLM-996): remove this WAR when custom allreduce is supported</span>
|
||
<span class="c1"># for encoder models in C++ runtime.</span>
|
||
<span class="n">workspace</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">!=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">NCCL</span> <span class="ow">and</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">!=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">current_all_reduce_helper</span><span class="p">()</span><span class="o">.</span><span class="n">workspace</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">NCCL</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">workspace</span> <span class="o">=</span> <span class="n">current_all_reduce_helper</span><span class="p">()</span><span class="o">.</span><span class="n">workspace</span><span class="o">.</span><span class="n">trt_tensor</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">==</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span><span class="p">:</span>
|
||
<span class="n">tensor</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="s2">"allreduce_ub_0_"</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">allreduce_ub_counter</span><span class="p">))</span>
|
||
<span class="n">dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
|
||
<span class="n">layer</span><span class="p">,</span> <span class="n">allreduce_plg_creator</span><span class="p">,</span> <span class="n">pfc</span> <span class="o">=</span> <span class="n">create_allreduce_plugin</span><span class="p">(</span>
|
||
<span class="n">network</span><span class="o">=</span><span class="n">default_trtnet</span><span class="p">(),</span>
|
||
<span class="n">tensor</span><span class="o">=</span><span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">workspace</span><span class="o">=</span><span class="n">workspace</span><span class="p">,</span>
|
||
<span class="n">group</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span>
|
||
<span class="n">all_reduce_params</span><span class="o">=</span><span class="n">all_reduce_params</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">allreduce_plg_creator</span><span class="p">,</span> <span class="s2">"allreduce"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">!=</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">NONE</span><span class="p">:</span>
|
||
<span class="n">inter_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">==</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span> <span class="ow">and</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_scale</span><span class="p">(</span>
|
||
<span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">final_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">==</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">RESIDUAL_RMS_NORM_QUANT_NVFP4</span><span class="p">:</span>
|
||
<span class="n">scale_factor</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">final_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">==</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_scale</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">final_output</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="s2">"allreduce_ub_1_"</span> <span class="o">+</span>
|
||
<span class="nb">str</span><span class="p">(</span><span class="n">allreduce_ub_counter</span><span class="p">))</span>
|
||
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">==</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">RESIDUAL_RMS_NORM_QUANT_NVFP4</span><span class="p">:</span>
|
||
<span class="n">scale_factor</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="s2">"allreduce_ub_2_"</span> <span class="o">+</span>
|
||
<span class="nb">str</span><span class="p">(</span><span class="n">allreduce_ub_counter</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="p">(</span><span class="n">final_output</span><span class="p">,</span> <span class="n">scale_factor</span><span class="p">),</span> <span class="n">inter_output</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">==</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">LAST_PROCESS_FOR_UB</span>
|
||
<span class="n">inter_output</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="s2">"allreduce_ub_1_"</span> <span class="o">+</span>
|
||
<span class="nb">str</span><span class="p">(</span><span class="n">allreduce_ub_counter</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="n">final_output</span><span class="p">,</span> <span class="n">inter_output</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">final_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">final_output</span></div>
<div class="viewcode-block" id="allgather">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.allgather">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">allgather</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">gather_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs a collective all-gather.</span>
|
||
|
||
<span class="sd"> Let's define 'group_size' as the length of the 'group' list. That functions</span>
|
||
<span class="sd"> creates a layer to gather 'group_size' tensors distributed</span>
|
||
<span class="sd"> amongst the 'group_size' participating ranks (one GPU per rank).</span>
|
||
|
||
<span class="sd"> The list 'group' contains the identifiers of the ranks participating into</span>
|
||
<span class="sd"> the collective operation.</span>
|
||
|
||
<span class="sd"> Note that 'group' here can be either TP group or PP group, because allgather communication is not limited to a specific split pattern. Therefore 'group_size' does not need to equal MPI 'world_size'.</span>
|
||
|
||
<span class="sd"> The tensors in the different ranks must be 1D tensors (or views) and the</span>
|
||
<span class="sd"> output tensor will have that same shape.</span>
|
||
|
||
<span class="sd"> Given the 'section_size = input.shape[0] / group_size', each rank</span>
|
||
<span class="sd"> contributes a section of its input tensor that correspond to</span>
|
||
<span class="sd"> 'rank*section_size:(rank+1)*section_size'.</span>
|
||
|
||
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL all-gather</span>
|
||
<span class="sd"> collective operation. See</span>
|
||
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather</span>
|
||
<span class="sd"> for details.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> group : List[int]</span>
|
||
<span class="sd"> The ranks participating into the all-gather operation.</span>
|
||
|
||
<span class="sd"> gather_dim: int = 0</span>
|
||
<span class="sd"> Gather along given dimension. By default 0, i.e. treated as 1D tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">allgather_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'AllGather'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">allgather_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">group_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">group</span><span class="p">)</span>
|
||
<span class="n">group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"group"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">group</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
|
||
<span class="n">allgather</span> <span class="o">=</span> <span class="n">allgather_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"allgather"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">allgather</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">allgather_plg_creator</span><span class="p">,</span> <span class="s2">"allgather"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
|
||
<span class="c1"># gather along a given dimension other than dim0</span>
|
||
<span class="k">if</span> <span class="n">gather_dim</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="c1"># also support -1 type of dim representation</span>
|
||
<span class="k">if</span> <span class="n">gather_dim</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">gather_dim</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="n">gather_dim</span>
|
||
|
||
<span class="c1"># plugin above gathers as 1D flattened tensor</span>
|
||
<span class="c1"># 1. [dim0, ...dimi, ...dimN] -> [group_size * dim0, ...dimi, ...dimN]</span>
|
||
|
||
<span class="c1"># now we need to gather-by-dim via split-concat</span>
|
||
<span class="c1"># 2. [group_size * dim0, ...dimi, ...dimN] -> [dim0, ...group_size * dimi, ...dimN]</span>
|
||
<span class="c1"># 2.1 split</span>
|
||
<span class="n">split_size</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">group_size</span>
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="n">starts</span> <span class="o">=</span> <span class="p">[</span><span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="mi">0</span><span class="p">]))</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
<span class="n">sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">d</span><span class="p">)</span> <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
|
||
<span class="n">sizes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">split_size</span>
|
||
<span class="n">sections</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">group_size</span><span class="p">):</span>
|
||
<span class="n">starts</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">split_size</span> <span class="o">*</span> <span class="n">i</span>
|
||
<span class="n">sections</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">concat</span><span class="p">(</span><span class="n">starts</span><span class="p">),</span> <span class="n">concat</span><span class="p">(</span><span class="n">sizes</span><span class="p">)))</span>
|
||
<span class="c1"># 2.2 concat</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">sections</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">gather_dim</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">x</span></div>
<div class="viewcode-block" id="reduce_scatter">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.reduce_scatter">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">reduce_scatter</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
|
||
<span class="n">plg_creater</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'ReduceScatter'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creater</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"group"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">group</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
|
||
|
||
<span class="n">reduce_scatter_plug</span> <span class="o">=</span> <span class="n">plg_creater</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"reduce_scatter"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">reduce_scatter_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creater</span><span class="p">,</span> <span class="s2">"reduce_scatter"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="send">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.send">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">send</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs a send from a rank to another.</span>
|
||
|
||
<span class="sd"> The send operation sends a tensor from one rank to another. If a rank 'i'</span>
|
||
<span class="sd"> sends a tensor to a rank 'j', the rank 'j' must have a corresponding 'recv'</span>
|
||
<span class="sd"> operation from rank 'i'. See 'recv'.</span>
|
||
|
||
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL send</span>
|
||
<span class="sd"> point-to-point operation. See</span>
|
||
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend</span>
|
||
<span class="sd"> for details.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> tgt : int</span>
|
||
<span class="sd"> The rank that receives the tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">send_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'Send'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">send_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">tgt</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"tgt_rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tgt</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">tgt</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
|
||
<span class="n">send_plug</span> <span class="o">=</span> <span class="n">send_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"send"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">send_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">send_plg_creator</span><span class="p">,</span> <span class="s2">"send"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="recv">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.recv">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">recv</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs a recv to a rank from another.</span>
|
||
|
||
<span class="sd"> The recv operation receives a tensor from on a rank from another. If a rank 'i'</span>
|
||
<span class="sd"> receives a tensor from a rank 'j', the rank 'j' must have a corresponding 'send'</span>
|
||
<span class="sd"> operation to rank 'j'. See 'send'.</span>
|
||
|
||
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL recv</span>
|
||
<span class="sd"> point-to-point operation. See</span>
|
||
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclrecv</span>
|
||
<span class="sd"> for details.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The input tensor.</span>
|
||
|
||
<span class="sd"> src : int</span>
|
||
<span class="sd"> The rank that sends the tensor to.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">recv_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'Recv'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">recv_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">src</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"src_rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">src</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
|
||
<span class="n">recv_plug</span> <span class="o">=</span> <span class="n">recv_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"recv"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">recv_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">recv_plg_creator</span><span class="p">,</span> <span class="s2">"recv"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="gemm_allreduce">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gemm_allreduce">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">gemm_allreduce</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">b</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
|
||
<span class="n">transa</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">transb</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">alpha</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">output_dtype</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">fp8_inputs_override</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">a_sf</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">b_sf</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs fused GEMM+AllReduce.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> a: Tensor</span>
|
||
<span class="sd"> Input tensor A</span>
|
||
<span class="sd"> b: Tensor</span>
|
||
<span class="sd"> Input tensor B</span>
|
||
<span class="sd"> a_sf: Optional[Tensor]</span>
|
||
<span class="sd"> Input tensor for scaling input A</span>
|
||
<span class="sd"> b_sf: Optional[Tensor]</span>
|
||
<span class="sd"> Input tensor for scaling input B</span>
|
||
<span class="sd"> group: List[int]</span>
|
||
<span class="sd"> Ranks participating in collective</span>
|
||
<span class="sd"> transa: bool</span>
|
||
<span class="sd"> Whether or not input tensor A is transposed</span>
|
||
<span class="sd"> transb: bool</span>
|
||
<span class="sd"> Whether or not input tensor B is transposed</span>
|
||
<span class="sd"> alpha: float</span>
|
||
<span class="sd"> Alpha for GEMM -> beta * C + (alpha * acc)</span>
|
||
<span class="sd"> output_dtype: trt.DataType</span>
|
||
<span class="sd"> Output type for plugin. If it is None, we</span>
|
||
<span class="sd"> will use type set in plugin_config.</span>
|
||
<span class="sd"> fp8_inputs_override: bool</span>
|
||
<span class="sd"> TRT graph does not detect FP8 inputs correctly. This</span>
|
||
<span class="sd"> flag is used to override the derived input tensor</span>
|
||
<span class="sd"> types so that our plugin knows to issue FP8 MMAs.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> Returns GEMM output tensor which has been reduced across ranks.</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="c1"># Output tensor needs to be bound to externally managed</span>
|
||
<span class="c1"># memory so keep track of layer index so we can assign</span>
|
||
<span class="c1"># output tensor unique label.</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">gemm_allreduce</span><span class="p">,</span> <span class="s1">'layer_idx'</span><span class="p">):</span>
|
||
<span class="n">gemm_allreduce</span><span class="o">.</span><span class="n">layer_idx</span> <span class="o">=</span> <span class="mi">0</span>
|
||
|
||
<span class="c1"># Check inputs</span>
|
||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">b</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">fp8_inputs_override</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="p">(</span>
|
||
<span class="nb">isinstance</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="ow">and</span> <span class="n">alpha</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span>
|
||
<span class="ow">and</span> <span class="n">alpha</span><span class="o">.</span><span class="n">size</span> <span class="o">==</span> <span class="mi">1</span>
|
||
<span class="p">),</span> <span class="s2">"`alpha` must be passed as a float32 ndarray if `fp8_inputs_override` is enabled for gemm_allreduce_plugin"</span>
|
||
<span class="k">assert</span> <span class="n">a</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span>
|
||
<span class="k">assert</span> <span class="n">b</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span>
|
||
|
||
<span class="k">if</span> <span class="n">output_dtype</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">output_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span>
|
||
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">output_dtype</span> <span class="ow">in</span> <span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">]</span>
|
||
|
||
<span class="n">alpha_is_tensor</span> <span class="o">=</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">alpha</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">alpha_is_tensor</span><span class="p">:</span>
|
||
<span class="n">alpha_value</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">alpha_value</span> <span class="o">=</span> <span class="n">alpha</span>
|
||
|
||
<span class="n">plugin_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'GemmAllReduce'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plugin_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">trt_type_a</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span> <span class="k">if</span> <span class="n">fp8_inputs_override</span> <span class="k">else</span> <span class="n">a</span><span class="o">.</span><span class="n">dtype</span>
|
||
<span class="n">trt_type_b</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span> <span class="k">if</span> <span class="n">fp8_inputs_override</span> <span class="k">else</span> <span class="n">b</span><span class="o">.</span><span class="n">dtype</span>
|
||
|
||
<span class="c1"># create plugin fields</span>
|
||
<span class="n">field_list</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">'type_a'</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">trt_type_a</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">))</span>
|
||
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">'type_b'</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">trt_type_b</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">))</span>
|
||
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">'type_d'</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">output_dtype</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">))</span>
|
||
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">'transa'</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">transa</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">))</span>
|
||
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">'transb'</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">transb</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">))</span>
|
||
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">'group'</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">))</span>
|
||
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">'has_sfa'</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">a_sf</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">))</span>
|
||
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">'has_sfb'</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">b_sf</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">))</span>
|
||
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">'alpha_is_ptr'</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">alpha_is_tensor</span><span class="p">)],</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">))</span>
|
||
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">'alpha'</span><span class="p">,</span> <span class="n">alpha_value</span><span class="o">.</span><span class="n">flatten</span><span class="p">(),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">))</span>
|
||
|
||
<span class="c1"># create plugin</span>
|
||
<span class="n">fields</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">(</span><span class="n">field_list</span><span class="p">)</span>
|
||
<span class="n">plugin</span> <span class="o">=</span> <span class="n">plugin_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"gemm_allreduce"</span><span class="p">,</span> <span class="n">fields</span><span class="p">)</span>
|
||
<span class="c1"># define symbolic input tensors.</span>
|
||
<span class="c1"># note this does NOT allocate memory.</span>
|
||
<span class="n">inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">b</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">a_sf</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">a_sf</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">b_sf</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">b_sf</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">alpha_is_tensor</span><span class="p">:</span>
|
||
<span class="n">inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">alpha</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">plugin</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plugin_creator</span><span class="p">,</span> <span class="s2">"gemm_allreduce"</span><span class="p">,</span> <span class="n">fields</span><span class="p">)</span>
|
||
<span class="c1"># define symbolic output tensors</span>
|
||
<span class="c1"># both output tensors point to same physical memory but</span>
|
||
<span class="c1"># one has unicast address and other has multicast address</span>
|
||
<span class="n">uc_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="n">mc_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="n">ipc_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">uc_output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="n">mc_output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="n">ipc_output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="c1"># mark outputs so that we can bind our own allocated memory in runtime</span>
|
||
<span class="c1"># (see generation.py)</span>
|
||
<span class="n">uc_output</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="sa">f</span><span class="s1">'gemm_allreduce_uc_out_</span><span class="si">{</span><span class="n">gemm_allreduce</span><span class="o">.</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
||
<span class="n">mc_output</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="sa">f</span><span class="s1">'gemm_allreduce_mc_out_</span><span class="si">{</span><span class="n">gemm_allreduce</span><span class="o">.</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
||
<span class="n">ipc_output</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="sa">f</span><span class="s1">'gemm_allreduce_ipc_out_</span><span class="si">{</span><span class="n">gemm_allreduce</span><span class="o">.</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
|
||
<span class="n">gemm_allreduce</span><span class="o">.</span><span class="n">layer_idx</span> <span class="o">+=</span> <span class="mi">1</span>
|
||
|
||
<span class="k">return</span> <span class="n">uc_output</span></div>
<div class="viewcode-block" id="bert_attention">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.bert_attention">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">bert_attention</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">input_lengths</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">head_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">q_scaling</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
|
||
<span class="n">relative_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">relative_attention_bias</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">max_distance</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">max_input_length</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">sage_attn</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">sage_attn_q_block_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">sage_attn_k_block_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">sage_attn_v_block_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">cp_group</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">cp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">cp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation that performs the multi-head attention in BERT.</span>
|
||
|
||
<span class="sd"> The multi-head attention (MHA) is the sequence of a batched matmul, a</span>
|
||
<span class="sd"> softmax and a batched matmul as described in</span>
|
||
<span class="sd"> https://arxiv.org/abs/1706.03762. That function adds an operation that</span>
|
||
<span class="sd"> performs those computations using a single GPU kernel.</span>
|
||
|
||
<span class="sd"> The input tensor contains the Q, K and V elements. It is a 2D tensor and</span>
|
||
<span class="sd"> its shape is '[sum_of_tokens, 3*hidden_dim]' where the 'sum_of_tokens' is</span>
|
||
<span class="sd"> the sum of the sequence lengths in the batch.</span>
|
||
|
||
<span class="sd"> In MHA, the output of the Q*K^T product is scaled by a constant value that</span>
|
||
<span class="sd"> is computed as:</span>
|
||
|
||
<span class="sd"> 1.f / (q_scaling * sqrt(head_size)).</span>
|
||
|
||
<span class="sd"> That 'q_scaling' constant is the last argument of that function.</span>
|
||
|
||
<span class="sd"> That layer is implemented using a plugin (see bertAttentionPlugin).</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> tensor : Tensor</span>
|
||
<span class="sd"> The QKV input tensor.</span>
|
||
|
||
<span class="sd"> input_lengths : Tensor</span>
|
||
<span class="sd"> The length of each sequence. It is a 1D tensor of size 'batch_size'.</span>
|
||
|
||
<span class="sd"> num_heads : int</span>
|
||
<span class="sd"> The number of heads.</span>
|
||
|
||
<span class="sd"> head_size : int</span>
|
||
<span class="sd"> The size of each head.</span>
|
||
|
||
<span class="sd"> q_scaling : float</span>
|
||
<span class="sd"> The factor to compute the scaling factor to scale the output of the</span>
|
||
<span class="sd"> 'Q*K^T' product.</span>
|
||
|
||
<span class="sd"> relative_attention: bool = False</span>
|
||
<span class="sd"> If enable relative attention.</span>
|
||
|
||
<span class="sd"> relative_attention_bias: Tensor = None</span>
|
||
<span class="sd"> The relative attention bias [num_heads, max_seq_len, max_seq_len], or The relative attention embedding table for implicit mode, [num_heads, num_buckets].</span>
|
||
|
||
<span class="sd"> max_distance: int = 0</span>
|
||
<span class="sd"> The maximum distance of relative position in attention, for implicit mode.</span>
|
||
<span class="sd"> Default value is 0, meaning to use the regular mode of relative attention bias.</span>
|
||
<span class="sd"> Implicit mode is only enabled when passing in non-zero positive max_distance value.</span>
|
||
<span class="sd"> See relative attention bias in docs/source/advanced/gpt-attention.md</span>
|
||
|
||
<span class="sd"> max_input_length: Tensor = None</span>
|
||
<span class="sd"> The maximum input sequence length represented by Tensor shape. Requires for remove_input_padding to pre-define plugin workspace size.</span>
|
||
|
||
<span class="sd"> sage_attn: bool = False</span>
|
||
<span class="sd"> SageAttention is a 8-bit implementation of attention kernel. It's input q, k, v and output datatypes are 16-bit. It performance dynamic quantization for q, k, v</span>
|
||
<span class="sd"> tensor every time before attention. https://github.com/thu-ml/SageAttention</span>
|
||
|
||
<span class="sd"> sage_attn_q_quant_size: int = 0</span>
|
||
<span class="sd"> dynamic quant block size along sequence dimension of q tensor. Each quant block will share one scale.</span>
|
||
|
||
<span class="sd"> sage_attn_k_quant_size: int = 0</span>
|
||
<span class="sd"> dynamic quant block size along sequence dimension of k tensor. Each quant block will share one scale.</span>
|
||
|
||
<span class="sd"> sage_attn_v_quant_size: int = 0</span>
|
||
<span class="sd"> dynamic quant block size along sequence dimension of v tensor. Each quant block will share one scale.</span>
|
||
|
||
<span class="sd"> cp_group: list[int] = None</span>
|
||
<span class="sd"> The communication group for context parallel</span>
|
||
|
||
<span class="sd"> cp_size: int = 1</span>
|
||
<span class="sd"> The communication size for context parallel</span>
|
||
|
||
<span class="sd"> cp_rank: int = 0</span>
|
||
<span class="sd"> The communication rank for context parallel</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
<span class="sd"> '''</span>
<span class="n">attn_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">'BertAttention'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">attn_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>

<span class="n">nheads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"num_heads"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">head_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"head_size"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">head_size</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">q_scaling</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"q_scaling"</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">q_scaling</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">context_fmha_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"context_fmha_type"</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">context_fmha_type</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">bert_attention_plugin</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">do_relative_attention</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"do_relative_attention"</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">relative_attention</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">max_distance</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"max_distance"</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_distance</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">remove_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"remove_padding"</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">sage_attn</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"sage_attn"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">sage_attn</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
|
||
<span class="n">sage_attn_q_block_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"sage_attn_q_block_size"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">sage_attn_q_block_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">sage_attn_k_block_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"sage_attn_k_block_size"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">sage_attn_k_block_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">sage_attn_v_block_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"sage_attn_v_block_size"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">sage_attn_v_block_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">cp_size</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="c1"># transpose q,k,v inside qkv to make kv contiguous, which is required by ring attention</span>
|
||
<span class="c1"># (b, s, 3d)</span>
|
||
<span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="n">chunk</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">bs</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">seq_len</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="c1"># (b, s, d) -> (b, s, 2d) -> (2b, s, d)</span>
|
||
<span class="n">kv</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">],</span>
|
||
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">((</span><span class="mi">2</span> <span class="o">*</span> <span class="n">bs</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])))</span>
|
||
<span class="n">tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">((</span><span class="n">query</span><span class="p">,</span> <span class="n">kv</span><span class="p">),</span>
|
||
<span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">((</span><span class="n">bs</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="mi">3</span><span class="p">)))</span>
|
||
|
||
<span class="n">cp_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"cp_size"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">cp_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">cp_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"cp_rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">cp_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">cp_group</span> <span class="o">=</span> <span class="n">cp_group</span> <span class="ow">or</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="n">cp_group</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">cp_group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
||
<span class="n">cp_group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"cp_group"</span><span class="p">,</span> <span class="n">cp_group</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
|
||
<span class="n">nheads</span><span class="p">,</span> <span class="n">head_size</span><span class="p">,</span> <span class="n">q_scaling</span><span class="p">,</span> <span class="n">context_fmha_type</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span>
|
||
<span class="n">do_relative_attention</span><span class="p">,</span> <span class="n">max_distance</span><span class="p">,</span> <span class="n">remove_padding</span><span class="p">,</span> <span class="n">sage_attn</span><span class="p">,</span>
|
||
<span class="n">sage_attn_q_block_size</span><span class="p">,</span> <span class="n">sage_attn_k_block_size</span><span class="p">,</span> <span class="n">sage_attn_v_block_size</span><span class="p">,</span>
|
||
<span class="n">cp_size</span><span class="p">,</span> <span class="n">cp_rank</span><span class="p">,</span> <span class="n">cp_group</span>
|
||
<span class="p">])</span>
|
||
|
||
<span class="n">attn_plug</span> <span class="o">=</span> <span class="n">attn_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"padding_attn"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="p">,</span> <span class="n">input_lengths</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">max_input_length</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="c1"># for remove padding mode</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">max_input_length</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">relative_attention_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="c1"># for relative attention mode</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">relative_attention_bias</span><span class="p">]</span>
|
||
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">attn_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">attn_plg_creator</span><span class="p">,</span> <span class="s2">"padding_attn"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> \
|
||
<span class="sa">f</span><span class="s2">"Plugin outputs number mismatch with expected, got </span><span class="si">{</span><span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span><span class="si">}</span><span class="s2">, expected 1"</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">RopeEmbeddingUtils</span><span class="p">:</span>
|
||
|
||
<div class="viewcode-block" id="RopeEmbeddingUtils.apply_llama3_scaling">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.apply_llama3_scaling">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="c1"># ref: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L298</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">apply_llama3_scaling</span><span class="p">(</span><span class="n">inv_freqs</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">rope_scaling_config</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
|
||
|
||
<span class="n">scale_factor</span> <span class="o">=</span> <span class="n">rope_scaling_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">"factor"</span><span class="p">,</span> <span class="mf">8.0</span><span class="p">)</span>
|
||
<span class="n">low_freq_factor</span> <span class="o">=</span> <span class="n">rope_scaling_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">"low_freq_factor"</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">)</span>
|
||
<span class="n">high_freq_factor</span> <span class="o">=</span> <span class="n">rope_scaling_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">"high_freq_factor"</span><span class="p">,</span> <span class="mf">4.0</span><span class="p">)</span>
|
||
<span class="n">old_context_len</span> <span class="o">=</span> <span class="n">rope_scaling_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
|
||
<span class="s2">"original_max_position_embeddings"</span><span class="p">,</span> <span class="mi">8192</span><span class="p">)</span>
|
||
|
||
<span class="n">low_freq_wavelen</span> <span class="o">=</span> <span class="n">old_context_len</span> <span class="o">/</span> <span class="n">low_freq_factor</span>
|
||
<span class="n">high_freq_wavelen</span> <span class="o">=</span> <span class="n">old_context_len</span> <span class="o">/</span> <span class="n">high_freq_factor</span>
|
||
<span class="n">new_inv_freqs</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">inv_freq</span> <span class="ow">in</span> <span class="n">inv_freqs</span><span class="p">:</span>
|
||
<span class="n">wavelen</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span> <span class="o">/</span> <span class="n">inv_freq</span>
|
||
<span class="k">if</span> <span class="n">wavelen</span> <span class="o"><</span> <span class="n">high_freq_wavelen</span><span class="p">:</span>
|
||
<span class="n">new_inv_freqs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">inv_freq</span><span class="p">)</span>
|
||
<span class="k">elif</span> <span class="n">wavelen</span> <span class="o">></span> <span class="n">low_freq_wavelen</span><span class="p">:</span>
|
||
<span class="n">new_inv_freqs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">inv_freq</span> <span class="o">/</span> <span class="n">scale_factor</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">low_freq_wavelen</span> <span class="o">!=</span> <span class="n">high_freq_wavelen</span>
|
||
<span class="n">smooth</span> <span class="o">=</span> <span class="p">(</span><span class="n">old_context_len</span> <span class="o">/</span> <span class="n">wavelen</span> <span class="o">-</span> <span class="n">low_freq_factor</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span>
|
||
<span class="n">high_freq_factor</span> <span class="o">-</span> <span class="n">low_freq_factor</span><span class="p">)</span>
|
||
<span class="n">new_inv_freqs</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="mi">1</span> <span class="o">-</span> <span class="n">smooth</span><span class="p">)</span> <span class="o">*</span> <span class="n">inv_freq</span> <span class="o">/</span> <span class="n">scale_factor</span> <span class="o">+</span>
|
||
<span class="n">smooth</span> <span class="o">*</span> <span class="n">inv_freq</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">new_inv_freqs</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">inv_freqs</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.create_sinusoidal_positions">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_sinusoidal_positions">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">create_sinusoidal_positions</span><span class="p">(</span><span class="n">num_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">theta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10000.0</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">):</span>
|
||
<span class="n">inv_freq</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="n">theta</span><span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"i , j -> i j"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_pos</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
|
||
<span class="n">inv_freq</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)),</span>
|
||
<span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">concat</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">create_sinusoidal_positions_for_attention_plugin</span><span class="p">(</span>
|
||
<span class="n">num_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">theta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10000.0</span><span class="p">,</span>
|
||
<span class="n">scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">scale_type</span><span class="p">:</span> <span class="n">RotaryScalingType</span> <span class="o">=</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">none</span><span class="p">,</span>
|
||
<span class="c1"># Other scaling configs that only used by certain scaling types.</span>
|
||
<span class="n">rope_scaling_config</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">scale_type</span> <span class="o">==</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">linear</span><span class="p">:</span>
|
||
<span class="n">scale</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">scale</span>
|
||
<span class="k">if</span> <span class="n">scale_type</span> <span class="o">==</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">llama3</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">rope_scaling_config</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"rotary_scaling config must be provided."</span>
|
||
<span class="n">inv_freq</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="n">theta</span><span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">inv_freq</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">apply_llama3_scaling</span><span class="p">(</span>
|
||
<span class="n">inv_freq</span><span class="p">,</span> <span class="n">rope_scaling_config</span><span class="p">)</span>
|
||
<span class="k">elif</span> <span class="n">scale_type</span> <span class="o">==</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">dynamic</span><span class="p">:</span>
|
||
<span class="c1"># Make sure scaling_alpha exists in rope_scaling</span>
|
||
<span class="c1"># Ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct-FP8/blob/main/modeling_hunyuan.py#L346</span>
|
||
<span class="k">assert</span> <span class="n">rope_scaling_config</span><span class="p">[</span>
|
||
<span class="s2">"alpha"</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"rope_scaling_config.alpha must be provided."</span>
|
||
<span class="n">scaling_alpha</span> <span class="o">=</span> <span class="n">rope_scaling_config</span><span class="p">[</span><span class="s2">"alpha"</span><span class="p">]</span>
|
||
<span class="n">adjusted_base</span> <span class="o">=</span> <span class="n">theta</span> <span class="o">*</span> <span class="p">(</span><span class="n">scaling_alpha</span><span class="o">**</span><span class="p">(</span><span class="n">dim</span> <span class="o">/</span> <span class="p">(</span><span class="n">dim</span> <span class="o">-</span> <span class="mi">2</span><span class="p">)))</span>
|
||
<span class="n">inv_freq</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="n">adjusted_base</span><span class="o">**</span><span class="p">(</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">inv_freq</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">/</span> <span class="p">(</span><span class="n">theta</span>
|
||
<span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"i , j -> i j"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_pos</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
|
||
<span class="n">inv_freq</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
|
||
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="c1"># fuse cos/sin into float2 (cos, sin).</span>
|
||
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
|
||
<span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)),</span>
|
||
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="c1">#np.cos(sinusoid_inp).shape = (32768, 64, 1)</span>
|
||
|
||
<span class="k">return</span> <span class="n">inv_freq</span><span class="p">,</span> <span class="n">concat</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.create_sinusoidal_positions_for_cogvlm_attention_plugin">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_sinusoidal_positions_for_cogvlm_attention_plugin">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">create_sinusoidal_positions_for_cogvlm_attention_plugin</span><span class="p">(</span>
|
||
<span class="n">num_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">theta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10000.0</span><span class="p">,</span>
|
||
<span class="n">scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">scale_type</span><span class="p">:</span> <span class="n">RotaryScalingType</span> <span class="o">=</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">none</span><span class="p">,</span>
|
||
<span class="n">vision_start</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">vision_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1225</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">scale_type</span> <span class="o">==</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">linear</span><span class="p">:</span>
|
||
<span class="n">scale</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">scale</span>
|
||
<span class="n">inv_freq</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">/</span> <span class="p">(</span><span class="n">theta</span><span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="n">position_id</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">([</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">vision_start</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">full</span><span class="p">(</span><span class="n">vision_length</span><span class="p">,</span> <span class="n">vision_start</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">vision_start</span> <span class="o">+</span> <span class="mi">2</span><span class="p">,</span>
|
||
<span class="n">num_pos</span> <span class="o">-</span> <span class="p">(</span><span class="n">vision_length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="p">])</span>
|
||
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"i , j -> i j"</span><span class="p">,</span>
|
||
<span class="n">position_id</span><span class="p">,</span>
|
||
<span class="n">inv_freq</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
|
||
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="c1"># fuse cos/sin into float2 (cos, sin).</span>
|
||
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)),</span>
|
||
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">inv_freq</span><span class="p">,</span> <span class="n">concat</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.create_sinusoidal_positions_long_rope_for_attention_plugin">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_sinusoidal_positions_long_rope_for_attention_plugin">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">create_sinusoidal_positions_long_rope_for_attention_plugin</span><span class="p">(</span>
|
||
<span class="n">num_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">num_orig_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">theta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10000.0</span><span class="p">,</span>
|
||
<span class="n">scaling_short_factors</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">scaling_long_factors</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">short_mscale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">long_mscale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">):</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_calc_mscale</span><span class="p">(</span><span class="n">scale</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">scale</span> <span class="o"><=</span> <span class="mf">1.0</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="mf">1.0</span>
|
||
<span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">scale</span><span class="p">)</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">num_orig_pos</span><span class="p">))</span>
|
||
|
||
<span class="k">if</span> <span class="n">short_mscale</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">short_mscale</span> <span class="o">=</span> <span class="n">_calc_mscale</span><span class="p">(</span><span class="n">num_pos</span> <span class="o">/</span> <span class="n">num_orig_pos</span><span class="p">)</span>
|
||
<span class="n">long_mscale</span> <span class="o">=</span> <span class="n">short_mscale</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_compute_sinusoidal_positions</span><span class="p">(</span><span class="n">scale_factors</span><span class="p">,</span> <span class="n">is_short</span><span class="p">,</span>
|
||
<span class="n">for_attention_plugin</span><span class="p">):</span>
|
||
<span class="n">inv_freq</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="n">scale_factors</span> <span class="o">*</span>
|
||
<span class="p">(</span><span class="n">theta</span><span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">))</span>
|
||
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"i , j -> i j"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_pos</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
|
||
<span class="n">inv_freq</span><span class="p">,</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">for_attention_plugin</span><span class="p">:</span>
|
||
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
|
||
<span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)),</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
|
||
<span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">concat</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
|
||
|
||
<span class="n">mscale</span> <span class="o">=</span> <span class="n">short_mscale</span> <span class="k">if</span> <span class="n">is_short</span> <span class="k">else</span> <span class="n">long_mscale</span>
|
||
<span class="n">concat</span> <span class="o">=</span> <span class="n">concat</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span> <span class="o">*</span> <span class="n">mscale</span>
|
||
|
||
<span class="c1"># gpt attention plugins also need inv_freq.</span>
|
||
<span class="k">if</span> <span class="n">for_attention_plugin</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">inv_freq</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">concat</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">concat</span>
|
||
|
||
<span class="k">return</span> <span class="n">_compute_sinusoidal_positions</span><span class="p">(</span>
|
||
<span class="n">scaling_short_factors</span><span class="p">,</span> <span class="kc">True</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span> <span class="n">_compute_sinusoidal_positions</span><span class="p">(</span>
|
||
<span class="n">scaling_long_factors</span><span class="p">,</span>
|
||
<span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span> <span class="n">_compute_sinusoidal_positions</span><span class="p">(</span>
|
||
<span class="n">scaling_short_factors</span><span class="p">,</span> <span class="kc">True</span><span class="p">,</span>
|
||
<span class="kc">True</span><span class="p">),</span> <span class="n">_compute_sinusoidal_positions</span><span class="p">(</span>
|
||
<span class="n">scaling_long_factors</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span> <span class="n">short_mscale</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.create_sinusoidal_positions_long_rope">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_sinusoidal_positions_long_rope">[docs]</a>
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">create_sinusoidal_positions_long_rope</span><span class="p">(</span>
|
||
<span class="n">num_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">theta</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
<span class="n">original_max_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">short_factor</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span>
<span class="n">long_factor</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">max_seq_len</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="n">short_factor</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">short_factor</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">long_factor</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">long_factor</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">inv_freq</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="n">theta</span><span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span>
<span class="n">t_pos</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">([</span><span class="n">num_pos</span><span class="p">,</span> <span class="n">original_max_pos</span><span class="p">]),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="c1"># Choose proper freqs based on max_seq_len.</span>
<span class="n">factor</span> <span class="o">=</span> <span class="n">long_factor</span> <span class="k">if</span> <span class="n">max_seq_len</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">max_seq_len</span> <span class="o">></span> <span class="n">original_max_pos</span> <span class="k">else</span> <span class="n">short_factor</span>
<span class="n">inv_freq</span> <span class="o">=</span> <span class="n">inv_freq</span> <span class="o">/</span> <span class="n">factor</span>
<span class="n">freqs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"i,j->ij"</span><span class="p">,</span> <span class="n">t_pos</span><span class="p">,</span> <span class="n">inv_freq</span><span class="p">)</span>
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">freqs</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)[</span><span class="o">...</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">]</span>
<span class="c1"># Apply scaling</span>
<span class="n">scale</span> <span class="o">=</span> <span class="n">num_pos</span> <span class="o">/</span> <span class="n">original_max_pos</span>
<span class="k">if</span> <span class="n">scale</span> <span class="o"><=</span> <span class="mf">1.0</span><span class="p">:</span>
<span class="n">scaling_factor</span> <span class="o">=</span> <span class="mf">1.0</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">scaling_factor</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">+</span>
<span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">scale</span><span class="p">)</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">original_max_pos</span><span class="p">))</span>
<span class="c1"># fuse cos/sin into float2 (cos, sin).</span>
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
<span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)</span> <span class="o">*</span> <span class="n">scaling_factor</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)</span> <span class="o">*</span> <span class="n">scaling_factor</span><span class="p">),</span>
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="kc">None</span><span class="p">,</span> <span class="n">concat</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.create_fake_weight">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_fake_weight">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">create_fake_weight</span><span class="p">(</span><span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">half</span><span class="p">):</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span></div>
<span class="c1"># Note: When not using deepseek_yarn, make sure to set mscale_all_dim to 0.0.</span>
<div class="viewcode-block" id="RopeEmbeddingUtils.create_sinusoidal_positions_yarn">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_sinusoidal_positions_yarn">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">create_sinusoidal_positions_yarn</span><span class="p">(</span>
<span class="n">num_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">base</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10000</span><span class="p">,</span>
<span class="n">scaling_factor</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">original_max_position_embeddings</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4096</span><span class="p">,</span>
<span class="n">beta_fast</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="n">beta_slow</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">mscale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">mscale_all_dim</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">duplicate_data</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">):</span>
<span class="c1"># Copy from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py</span>
<span class="c1"># Inverse dim formula to find dim based on number of rotations</span>
<span class="k">def</span><span class="w"> </span><span class="nf">yarn_find_correction_dim</span><span class="p">(</span><span class="n">num_rotations</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">base</span><span class="p">,</span>
<span class="n">max_position_embeddings</span><span class="p">):</span>
<span class="k">return</span> <span class="p">(</span><span class="n">dim</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">max_position_embeddings</span> <span class="o">/</span>
<span class="p">(</span><span class="n">num_rotations</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span><span class="p">)))</span> <span class="o">/</span> <span class="p">(</span>
<span class="mi">2</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">base</span><span class="p">))</span>
<span class="c1"># Find dim range bounds based on rotations</span>
<span class="k">def</span><span class="w"> </span><span class="nf">yarn_find_correction_range</span><span class="p">(</span><span class="n">low_rot</span><span class="p">,</span> <span class="n">high_rot</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">base</span><span class="p">,</span>
<span class="n">max_position_embeddings</span><span class="p">):</span>
<span class="n">low</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span>
<span class="n">yarn_find_correction_dim</span><span class="p">(</span><span class="n">low_rot</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">base</span><span class="p">,</span>
<span class="n">max_position_embeddings</span><span class="p">))</span>
<span class="n">high</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span>
<span class="n">yarn_find_correction_dim</span><span class="p">(</span><span class="n">high_rot</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">base</span><span class="p">,</span>
<span class="n">max_position_embeddings</span><span class="p">))</span>
<span class="k">if</span> <span class="n">low</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">low</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="n">high</span> <span class="o">></span> <span class="n">dim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">high</span> <span class="o">=</span> <span class="n">dim</span> <span class="o">-</span> <span class="mi">1</span>
<span class="k">return</span> <span class="n">low</span><span class="p">,</span> <span class="n">high</span> <span class="c1"># Clamp values just in case</span>
<span class="k">def</span><span class="w"> </span><span class="nf">yarn_get_mscale</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span> <span class="n">mscale</span><span class="p">):</span>
<span class="k">if</span> <span class="n">scale</span> <span class="o"><=</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">return</span> <span class="mf">1.0</span>
<span class="k">return</span> <span class="mf">0.1</span> <span class="o">*</span> <span class="n">mscale</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">scale</span><span class="p">)</span> <span class="o">+</span> <span class="mf">1.0</span>
<span class="k">def</span><span class="w"> </span><span class="nf">yarn_linear_ramp_mask</span><span class="p">(</span><span class="nb">min</span><span class="p">,</span> <span class="nb">max</span><span class="p">,</span> <span class="n">dim</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">min</span> <span class="o">==</span> <span class="nb">max</span><span class="p">:</span>
<span class="nb">max</span> <span class="o">+=</span> <span class="mf">0.001</span> <span class="c1"># Prevent singularity</span>
<span class="n">linear_func</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">-</span> <span class="nb">min</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="nb">max</span> <span class="o">-</span> <span class="nb">min</span><span class="p">)</span>
<span class="n">ramp_func</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="n">linear_func</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">ramp_func</span>
<span class="n">pos_freqs</span> <span class="o">=</span> <span class="n">base</span><span class="o">**</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">)</span>
<span class="n">freq_extra</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">pos_freqs</span>
<span class="n">freq_inter</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="n">scaling_factor</span> <span class="o">*</span> <span class="n">pos_freqs</span><span class="p">)</span>
<span class="n">low</span><span class="p">,</span> <span class="n">high</span> <span class="o">=</span> <span class="n">yarn_find_correction_range</span><span class="p">(</span>
<span class="n">beta_fast</span><span class="p">,</span>
<span class="n">beta_slow</span><span class="p">,</span>
<span class="n">dim</span><span class="p">,</span>
<span class="n">base</span><span class="p">,</span>
<span class="n">original_max_position_embeddings</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">inv_freq_mask</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">yarn_linear_ramp_mask</span><span class="p">(</span><span class="n">low</span><span class="p">,</span> <span class="n">high</span><span class="p">,</span> <span class="n">dim</span> <span class="o">//</span> <span class="mi">2</span><span class="p">))</span>
<span class="n">inv_freq</span> <span class="o">=</span> <span class="n">freq_inter</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">inv_freq_mask</span><span class="p">)</span> <span class="o">+</span> <span class="n">freq_extra</span> <span class="o">*</span> <span class="n">inv_freq_mask</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_pos</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"i,j -> ij"</span><span class="p">,</span> <span class="n">t</span><span class="p">,</span> <span class="n">inv_freq</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">_mscale</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span>
<span class="n">yarn_get_mscale</span><span class="p">(</span><span class="n">scaling_factor</span><span class="p">,</span> <span class="n">mscale</span><span class="p">)</span> <span class="o">/</span>
<span class="n">yarn_get_mscale</span><span class="p">(</span><span class="n">scaling_factor</span><span class="p">,</span> <span class="n">mscale_all_dim</span><span class="p">))</span>
<span class="k">if</span> <span class="n">duplicate_data</span><span class="p">:</span>
<span class="n">emb</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">sinusoid_inp</span><span class="p">,</span> <span class="n">sinusoid_inp</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">2</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">emb</span> <span class="o">=</span> <span class="n">sinusoid_inp</span>
<span class="n">concat</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">torch</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">emb</span><span class="p">)</span> <span class="o">*</span> <span class="n">_mscale</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">emb</span><span class="p">)</span> <span class="o">*</span> <span class="n">_mscale</span><span class="p">),</span>
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">inv_freq</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">concat</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.rotate_every_two">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.rotate_every_two">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">rotate_every_two</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">4</span>
<span class="n">shape_tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
<span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span>
<span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
<span class="p">])</span>
<span class="n">x1</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">shape_tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
<span class="n">x2</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">shape_tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
<span class="n">x1</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
<span class="n">x2</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">x2</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
<span class="n">zero</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span>
<span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span>
<span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">1</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">))))</span>
<span class="n">x2</span> <span class="o">=</span> <span class="n">zero</span> <span class="o">-</span> <span class="n">x2</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">x2</span><span class="p">,</span> <span class="n">x1</span><span class="p">],</span> <span class="mi">4</span><span class="p">)</span>
<span class="k">return</span> <span class="n">view</span><span class="p">(</span>
<span class="n">x</span><span class="p">,</span> <span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span>
<span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span><span class="p">]))</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.rotate_half">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.rotate_half">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">rotate_half</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="c1"># [bs, num_attention_kv_heads, seqlen, attention_head_size]</span>
<span class="k">assert</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">4</span>
<span class="n">shape_tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
<span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span>
<span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
<span class="p">])</span>
<span class="n">last_dim</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span>
<span class="n">x1</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">shape_tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
<span class="n">x2</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">concat</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">last_dim</span><span class="p">]),</span> <span class="n">shape_tensor</span><span class="p">,</span>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
<span class="n">zero</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span>
<span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span>
<span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">1</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">))))</span>
<span class="n">x2</span> <span class="o">=</span> <span class="n">zero</span> <span class="o">-</span> <span class="n">x2</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">x2</span><span class="p">,</span> <span class="n">x1</span><span class="p">],</span> <span class="mi">3</span><span class="p">)</span>
<span class="k">return</span> <span class="n">x</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.apply_rotary_pos_emb">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.apply_rotary_pos_emb">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">apply_rotary_pos_emb</span><span class="p">(</span>
<span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">position_embedding</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">pos_emb_type</span><span class="p">:</span> <span class="n">PositionEmbeddingType</span> <span class="o">=</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gptj</span>
<span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="n">rotate_func</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">pos_emb_type</span> <span class="o">==</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gpt_neox</span> <span class="ow">or</span> <span class="n">pos_emb_type</span> <span class="o">==</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">long_rope</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span>
<span class="n">cos</span><span class="p">,</span> <span class="n">sin</span> <span class="o">=</span> <span class="n">position_embedding</span>
<span class="n">sin</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">sin</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">cos</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">cos</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">sin</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">sin</span><span class="p">,</span> <span class="n">sin</span><span class="p">],</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">cos</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">cos</span><span class="p">,</span> <span class="n">cos</span><span class="p">],</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">rotate_func</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">rotate_half</span>
<span class="k">elif</span> <span class="n">pos_emb_type</span> <span class="o">==</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gptj</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span>
<span class="n">cos</span><span class="p">,</span> <span class="n">sin</span> <span class="o">=</span> <span class="n">position_embedding</span>
<span class="n">sin</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">sin</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">cos</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">cos</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">sin</span> <span class="o">=</span> <span class="n">repeat_interleave</span><span class="p">(</span><span class="n">sin</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">cos</span> <span class="o">=</span> <span class="n">repeat_interleave</span><span class="p">(</span><span class="n">cos</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">rotate_func</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">rotate_every_two</span>
<span class="k">elif</span> <span class="n">pos_emb_type</span> <span class="o">==</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">chatglm</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">)</span> <span class="o">==</span> <span class="mi">4</span>
<span class="n">cos0</span><span class="p">,</span> <span class="n">cos1</span><span class="p">,</span> <span class="n">sin0</span><span class="p">,</span> <span class="n">sin1</span> <span class="o">=</span> <span class="n">position_embedding</span>
<span class="n">shape_tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
<span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span>
<span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
<span class="p">])</span>
<span class="n">last_dim</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span>
<span class="n">x_part0</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">shape_tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
<span class="n">x_part1</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">concat</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">last_dim</span><span class="p">]),</span> <span class="n">shape_tensor</span><span class="p">,</span>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
<span class="n">y_part0</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_part0</span> <span class="o">*</span>
<span class="n">cos0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">rotate_half</span><span class="p">(</span><span class="n">x_part0</span><span class="p">)</span> <span class="o">*</span> <span class="n">sin0</span><span class="p">)</span>
<span class="n">y_part1</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_part1</span> <span class="o">*</span>
<span class="n">cos1</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">rotate_half</span><span class="p">(</span><span class="n">x_part1</span><span class="p">)</span> <span class="o">*</span> <span class="n">sin1</span><span class="p">)</span>
<span class="n">result</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">y_part0</span><span class="p">,</span> <span class="n">y_part1</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
<span class="k">return</span> <span class="n">result</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">'The PositionEmbeddingType is not RoPE'</span><span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">tensor</span> <span class="o">*</span> <span class="n">cos</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">rotate_func</span><span class="p">(</span><span class="n">tensor</span><span class="p">)</span> <span class="o">*</span> <span class="n">sin</span><span class="p">)</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.apply_rotary_pos_emb_chatglm">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.apply_rotary_pos_emb_chatglm">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">apply_rotary_pos_emb_chatglm</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="n">position_embedding</span><span class="p">,</span>
<span class="n">num_attention_heads</span><span class="p">,</span> <span class="n">attention_head_size</span><span class="p">,</span>
<span class="n">max_position_embeddings</span><span class="p">,</span>
<span class="n">rotary_embedding_scale</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="n">half_head_size</span> <span class="o">=</span> <span class="n">attention_head_size</span> <span class="o">//</span> <span class="mi">2</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="k">else</span> <span class="n">qkv</span>
<span class="n">input_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">remove_input_padding</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">seqlen</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span> <span class="k">if</span> <span class="n">remove_input_padding</span> <span class="k">else</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="n">qkv</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">seqlen</span><span class="p">,</span>
<span class="n">num_attention_heads</span><span class="p">,</span>
<span class="mi">3</span><span class="p">,</span>
<span class="n">attention_head_size</span><span class="p">,</span>
<span class="p">]))</span>
<span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
<span class="n">q_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">seqlen</span><span class="p">,</span>
<span class="n">num_attention_heads</span><span class="p">,</span>
<span class="n">attention_head_size</span><span class="p">,</span>
<span class="p">])</span>
<span class="n">query</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
<span class="n">value</span> <span class="o">=</span> <span class="n">value</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">create_sinusoidal_positions</span><span class="p">(</span>
<span class="n">max_position_embeddings</span><span class="p">,</span> <span class="n">half_head_size</span><span class="p">)</span>
<span class="n">embedding_weight</span> <span class="o">/=</span> <span class="n">rotary_embedding_scale</span>
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">embedding_weight</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="mi">2</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
<span class="p">[</span>
<span class="n">embedding_weight</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
<span class="n">embedding_weight</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
<span class="n">embedding_weight</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span>
<span class="n">embedding_weight</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span>
<span class="p">],</span>
<span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">position_embedding</span> <span class="o">=</span> <span class="n">unsqueeze</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">embedding_weight</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="n">query</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span>
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">embedding_weight</span><span class="p">)</span>
<span class="n">position_embedding</span> <span class="o">=</span> <span class="n">embedding</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">,</span> <span class="n">embedding_weight</span><span class="p">)</span>
<span class="n">position_embedding</span><span class="p">,</span> <span class="n">block_embedding</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span>
<span class="n">position_embedding</span><span class="p">,</span>
<span class="mi">1</span><span class="p">,</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">sin0</span><span class="p">,</span> <span class="n">cos0</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">,</span> <span class="n">half_head_size</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
<span class="n">sin1</span><span class="p">,</span> <span class="n">cos1</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="n">block_embedding</span><span class="p">,</span> <span class="n">half_head_size</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">seqlen</span><span class="p">,</span>
<span class="mi">1</span><span class="p">,</span>
<span class="n">half_head_size</span><span class="p">,</span>
<span class="p">])</span>
<span class="n">position_embedding</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">tensor</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span> <span class="k">for</span> <span class="n">tensor</span> <span class="ow">in</span> <span class="p">[</span><span class="n">cos0</span><span class="p">,</span> <span class="n">cos1</span><span class="p">,</span> <span class="n">sin0</span><span class="p">,</span> <span class="n">sin1</span><span class="p">]</span>
<span class="p">]</span>
<span class="n">query</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">apply_rotary_pos_emb</span><span class="p">(</span>
<span class="n">tensor</span><span class="o">=</span><span class="n">query</span><span class="p">,</span>
<span class="n">position_embedding</span><span class="o">=</span><span class="n">position_embedding</span><span class="p">,</span>
<span class="n">pos_emb_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">chatglm</span><span class="p">)</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">apply_rotary_pos_emb</span><span class="p">(</span>
<span class="n">tensor</span><span class="o">=</span><span class="n">key</span><span class="p">,</span>
<span class="n">position_embedding</span><span class="o">=</span><span class="n">position_embedding</span><span class="p">,</span>
<span class="n">pos_emb_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">chatglm</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">qkv</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">query</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
<span class="n">key</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
<span class="n">value</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
<span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">qkv</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span>
<span class="k">return</span> <span class="n">qkv</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.apply_rotary_pos_emb_cogvlm">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.apply_rotary_pos_emb_cogvlm">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">apply_rotary_pos_emb_cogvlm</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="n">position_embedding</span><span class="p">,</span>
<span class="n">num_attention_heads</span><span class="p">,</span> <span class="n">attention_head_size</span><span class="p">,</span>
<span class="n">max_position_embeddings</span><span class="p">,</span>
<span class="n">rotary_embedding_scale</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="k">else</span> <span class="n">qkv</span>
<span class="n">input_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
|
||
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">remove_input_padding</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">seqlen</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span> <span class="k">if</span> <span class="n">remove_input_padding</span> <span class="k">else</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
|
||
<span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="n">qkv</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span>
|
||
<span class="n">batch_size</span><span class="p">,</span>
|
||
<span class="n">seqlen</span><span class="p">,</span>
|
||
<span class="mi">3</span><span class="p">,</span>
|
||
<span class="n">num_attention_heads</span><span class="p">,</span>
|
||
<span class="n">attention_head_size</span><span class="p">,</span>
|
||
<span class="p">]))</span>
|
||
<span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
|
||
<span class="n">q_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
|
||
<span class="n">batch_size</span><span class="p">,</span>
|
||
<span class="n">seqlen</span><span class="p">,</span>
|
||
<span class="n">num_attention_heads</span><span class="p">,</span>
|
||
<span class="n">attention_head_size</span><span class="p">,</span>
|
||
<span class="p">])</span>
|
||
<span class="n">query</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
|
||
<span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
|
||
<span class="n">value</span> <span class="o">=</span> <span class="n">value</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
|
||
|
||
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">create_sinusoidal_positions</span><span class="p">(</span>
|
||
<span class="n">max_position_embeddings</span><span class="p">,</span> <span class="n">attention_head_size</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">embedding_weight</span> <span class="o">/=</span> <span class="n">rotary_embedding_scale</span> <span class="c1"># [max_position_embeddings, attention_head_size]</span>
|
||
|
||
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
|
||
<span class="n">position_embedding</span> <span class="o">=</span> <span class="n">unsqueeze</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="c1"># [1, seqlen]</span>
|
||
|
||
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">embedding_weight</span><span class="p">)</span> <span class="c1"># float32</span>
|
||
<span class="n">position_embedding</span> <span class="o">=</span> <span class="n">embedding</span><span class="p">(</span>
|
||
<span class="n">position_embedding</span><span class="p">,</span>
|
||
<span class="n">embedding_weight</span><span class="p">)</span> <span class="c1"># [1, seqlen, attention_head_size]</span>
|
||
<span class="n">sin</span><span class="p">,</span> <span class="n">cos</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">,</span> <span class="n">attention_head_size</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span>
|
||
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># [1, seqlen, attention_head_size//2]</span>
|
||
|
||
<span class="n">input_dtype</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">dtype</span>
|
||
<span class="n">fp32_query</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="s2">"float32"</span><span class="p">)</span>
|
||
<span class="n">fp32_key</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="s2">"float32"</span><span class="p">)</span>
|
||
<span class="n">fp32_query</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">apply_rotary_pos_emb</span><span class="p">(</span>
|
||
<span class="n">tensor</span><span class="o">=</span><span class="n">fp32_query</span><span class="p">,</span>
|
||
<span class="n">position_embedding</span><span class="o">=</span><span class="p">[</span><span class="n">cos</span><span class="p">,</span> <span class="n">sin</span><span class="p">],</span>
|
||
<span class="n">pos_emb_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">)</span>
|
||
<span class="n">fp32_key</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">apply_rotary_pos_emb</span><span class="p">(</span>
|
||
<span class="n">tensor</span><span class="o">=</span><span class="n">fp32_key</span><span class="p">,</span>
|
||
<span class="n">position_embedding</span><span class="o">=</span><span class="p">[</span><span class="n">cos</span><span class="p">,</span> <span class="n">sin</span><span class="p">],</span>
|
||
<span class="n">pos_emb_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">)</span>
|
||
|
||
<span class="n">query</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">fp32_query</span><span class="p">,</span> <span class="n">input_dtype</span><span class="p">)</span>
|
||
<span class="n">key</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">fp32_key</span><span class="p">,</span> <span class="n">input_dtype</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
|
||
<span class="n">qkv</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="n">query</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
|
||
<span class="n">key</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
|
||
<span class="n">value</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
|
||
<span class="p">]</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">qkv</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
|
||
<span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">qkv</span></div>
|
||
</div>
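
# --- Illustrative sketch (not part of tensorrt_llm; plain NumPy) of the
# rope_gpt_neox-style rotation that apply_rotary_pos_emb_cogvlm applies in
# float32, with sin/cos gathered from a sinusoidal table that is divided by
# rotary_embedding_scale. Sizes below are hypothetical.
import numpy as np

def neox_rope(x, sin_half, cos_half):
    # x: [seqlen, head_size]; sin_half/cos_half: [seqlen, head_size // 2].
    sin = np.concatenate([sin_half, sin_half], axis=-1)
    cos = np.concatenate([cos_half, cos_half], axis=-1)
    x1, x2 = np.split(x, 2, axis=-1)
    return x * cos + np.concatenate([-x2, x1], axis=-1) * sin

head_size, max_pos, scale = 8, 32, 1.0
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_size, 2) / head_size))
table = np.outer(np.arange(max_pos), inv_freq) / scale  # [max_pos, head_size // 2]
positions = np.array([0, 1, 2, 3])  # rows gathered per token, as embedding() does
q = np.random.rand(len(positions), head_size).astype(np.float32)
q_rotated = neox_rope(q, np.sin(table[positions]), np.cos(table[positions]))
# --- end of sketch ---
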
<div class="viewcode-block" id="gpt_attention">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gpt_attention">[docs]</a>
|
||
<span class="nd">@gw</span><span class="o">.</span><span class="n">record_signature</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">gpt_attention</span><span class="p">(</span>
|
||
<span class="o">*</span><span class="p">,</span>
|
||
<span class="n">qkv</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">past_key_value</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">attention_packed_mask</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">sequence_length</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">host_past_key_value_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span>
|
||
<span class="n">host_max_attention_window_sizes</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">host_sink_token_length</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span>
|
||
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span>
|
||
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">layer_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">num_kv_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">hidden_size_per_head</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">q_scaling</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
|
||
<span class="n">attn_logit_softcapping_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_base</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10000.0</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_scale_type</span><span class="p">:</span> <span class="n">RotaryScalingType</span> <span class="o">=</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">none</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_short_m_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_long_m_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_max_positions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1024</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_original_max_positions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1024</span><span class="p">,</span>
|
||
<span class="n">position_embedding_type</span><span class="p">:</span> <span class="n">PositionEmbeddingType</span> <span class="o">=</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span>
|
||
<span class="n">learned_absolute</span><span class="p">,</span>
|
||
<span class="n">rotary_inv_freq</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">rotary_cos_sin</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">kv_orig_quant_scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">kv_quant_orig_scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">attention_output_orig_quant_scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">attention_output_sf_scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">kv_cache_quant_mode</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">QuantModeWrapper</span><span class="p">,</span> <span class="n">QuantMode</span><span class="p">]</span> <span class="o">=</span> <span class="n">QuantMode</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">max_context_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">mask_type</span><span class="p">:</span> <span class="n">AttentionMaskType</span> <span class="o">=</span> <span class="n">AttentionMaskType</span><span class="o">.</span><span class="n">causal</span><span class="p">,</span>
|
||
<span class="n">block_sparse_block_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span><span class="p">,</span>
|
||
<span class="n">block_sparse_homo_head_pattern</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">block_sparse_num_local_blocks</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span><span class="p">,</span>
|
||
<span class="n">block_sparse_vertical_stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span>
|
||
<span class="n">alibi_slopes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">tp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">tp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">vision_start</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">vision_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">kv_cache_block_offsets</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">host_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">host_kv_cache_pool_pointers</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">host_kv_cache_pool_mapping</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">do_cross_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">cross_kv</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for cross attention</span>
|
||
<span class="n">cross_kv_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for cross attention</span>
|
||
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for cross attention</span>
|
||
<span class="n">relative_attention_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for relative attention</span>
|
||
<span class="n">logn_scaling</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for logn scaling</span>
|
||
<span class="n">max_distance</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="c1"># for relative attention</span>
|
||
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for pad-free input mode</span>
|
||
<span class="n">qkv_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">use_cache</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">spec_decoding_is_generation_length_variable</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">spec_decoding_max_generation_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">spec_decoding_generation_lengths</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">spec_decoding_position_offsets</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">spec_decoding_packed_mask</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">spec_decoding_use</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">long_rope_rotary_inv_freq</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">long_rope_rotary_cos_sin</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">mrope_rotary_cos_sin</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">mrope_position_deltas</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">host_runtime_perf_knobs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">host_context_progress</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">is_mla_enabled_flag</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">q_lora_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">kv_lora_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">qk_nope_head_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">qk_rope_head_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">v_head_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">q_b_proj</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">kv_b_proj</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">k_b_proj_trans</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">skip_attn</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">cp_group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">],</span>
|
||
<span class="n">cp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">cp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">num_kv_heads_origin</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]]:</span>
<span class="w">    </span><span class="sd">'''</span>
<span class="sd">    Add an operation that performs the multi-head attention in GPT-like models.</span>

<span class="sd">    The signature of the function will change in a future release - we are in</span>
<span class="sd">    the process of simplifying the API. The current version is still</span>
<span class="sd">    work-in-progress! The following API is provided with hints regarding the</span>
<span class="sd">    arguments that are likely to be removed or merged with others in a future</span>
<span class="sd">    release.</span>

<span class="sd">    See docs/source/advanced/gpt-attention.md for the documentation of that function.</span>

<span class="sd">    Parameters:</span>
<span class="sd">        qkv: Tensor (On GPU)</span>
<span class="sd">            The input QKV tensor. Its shape is [batch_beam_size, max_seqlen, qkv_dim] in padded mode and [num_tokens, qkv_dim] in</span>
<span class="sd">            packed mode, where qkv_dim depends on whether MQA, GQA, or MHA is used. See QKV Input in docs/source/advanced/gpt-attention.md,</span>

<span class="sd">        past_key_value: Tensor (On GPU)</span>
<span class="sd">            The tensor that stores KV cache data. Its shape is</span>
<span class="sd">            [max_batch_size * max_beam_width, 2, num_kv_heads, max_seqlen, hidden_dim_per_head]</span>
<span class="sd">            in contiguous mode and</span>
<span class="sd">            [max_blocks, 2, num_kv_heads, num_tokens_per_block, hidden_dim_per_head]</span>
<span class="sd">            in paged mode. See KV Cache in docs/source/advanced/gpt-attention.md,</span>

<span class="sd">        attention_mask: Tensor (On GPU)</span>
<span class="sd">            The tensor that stores the attention mask for unfused MHA or MMHA.</span>
<span class="sd">            Its shape is [num_tokens, max_kv_seqlen].</span>

<span class="sd">        attention_packed_mask: Tensor (On GPU)</span>
<span class="sd">            The tensor that stores the packed custom mask for fmha.</span>
<span class="sd">            Its shape is [num_tokens, max_kv_seqlen / 32], where each bit represents one mask position.</span>

<span class="sd">        sequence_length: Tensor (On GPU)</span>
<span class="sd">            The tensor that stores the length of each sequence. Its shape is</span>
<span class="sd">            [batch_size]. See QKV Input in docs/source/advanced/gpt-attention.md,</span>

<span class="sd">        host_past_key_value_lengths: Tensor (On CPU)</span>
<span class="sd">            An INT32 tensor of shape [batch_size],</span>

<span class="sd">        host_max_attention_window_sizes: Tensor (On CPU)</span>
<span class="sd">            An INT32 tensor of shape [1].</span>
<span class="sd">            By default, the max_attention_window_size is determined by the shape of cache_indir_table.</span>
<span class="sd">            Each layer may use an independent max_attention_window_size.</span>
<span class="sd">            This controls the sliding-window-attention/cyclic-kv-cache features.</span>

<span class="sd">        context_lengths: Tensor (On GPU)</span>
<span class="sd">            The tensor that stores the context-phase sequence length of each request. Its shape</span>
<span class="sd">            is [batch_size]. See QKV Input in docs/source/advanced/gpt-attention.md,</span>

<span class="sd">        cache_indirection: Tensor (On GPU)</span>
<span class="sd">            The tensor to reconstruct the paths when using beam-search. Its</span>
<span class="sd">            shape is [batch_size, beam_width, max_seqlen]. See Beam-Search in</span>
<span class="sd">            docs/source/advanced/gpt-attention.md,</span>

<span class="sd">        host_request_types: Tensor (On CPU)</span>
<span class="sd">            The tensor on the host that indicates if a request is in context or</span>
<span class="sd">            generation phase. Its shape is [batch_size]. See Inflight Batching</span>
<span class="sd">            in docs/source/advanced/gpt-attention.md,</span>

<span class="sd">        layer_idx: int</span>
<span class="sd">            The index of this attention layer, used to access kv_cache_block_offsets,</span>

<span class="sd">        num_heads: int</span>
<span class="sd">            The number of heads,</span>

<span class="sd">        num_kv_heads: int</span>
<span class="sd">            The number of KV heads, generic to handle MHA/MQA/GQA,</span>

<span class="sd">        hidden_size_per_head: int</span>
<span class="sd">            The hidden size per head,</span>

<span class="sd">        q_scaling: float</span>
<span class="sd">            The value used to compute the scaling factor applied to the output</span>
<span class="sd">            of the Q*K^T product. See Scaling Factors in docs/source/advanced/gpt-attention.md,</span>

<span class="sd">        attn_logit_softcapping_scale: float</span>
<span class="sd">            When non-zero, the attention logits are soft-capped as</span>
<span class="sd">            scale * tanh(value / scale) on the output of the Q*K^T product.</span>

<span class="sd">        rotary_embedding_dim: int</span>
<span class="sd">            The dimension to compute RoPE. Use 0 when position_embedding_type is not RoPE.</span>

<span class="sd">        rotary_embedding_base: float</span>
<span class="sd">            The theta value to use for RoPE. Ignored when position_embedding_type is not RoPE.</span>

<span class="sd">        rotary_embedding_scale_type: RotaryScalingType</span>
<span class="sd">            The scaling type of RoPE. Ignored when position_embedding_type is not RoPE.</span>
<span class="sd">            Possible rotary scaling types:</span>
<span class="sd">                * RotaryScalingType.none</span>
<span class="sd">                * RotaryScalingType.linear</span>
<span class="sd">                * RotaryScalingType.dynamic</span>
<span class="sd">                * RotaryScalingType.longrope</span>
<span class="sd">                * RotaryScalingType.llama3</span>

<span class="sd">        rotary_embedding_scale: float</span>
<span class="sd">            The scale value to use for linear/dynamic scaling in RoPE.</span>
<span class="sd">            Ignored when position_embedding_type is not RoPE.</span>
<span class="sd">            Must be set to 1 (default) if rotary_embedding_scale_type is `none`.</span>

<span class="sd">        rotary_inv_freq: float Tensor</span>
<span class="sd">            The rotary inv freq with shape [head_size / 2].</span>

<span class="sd">        rotary_cos_sin: float2(cos/sin) Tensor</span>
<span class="sd">            The rotary cos/sin cache, which will be reused among different requests.</span>
<span class="sd">            It is taken as a constant tensor.</span>

<span class="sd">        rotary_embedding_max_positions: int</span>
<span class="sd">            Needed only for `dynamic` RoPE scaling. Ignored otherwise.</span>

<span class="sd">        position_embedding_type: PositionEmbeddingType</span>
<span class="sd">            The position embedding type:</span>
<span class="sd">                * PositionEmbeddingType.learned_absolute</span>
<span class="sd">                * PositionEmbeddingType.relative</span>
<span class="sd">                * PositionEmbeddingType.rope_gptj</span>
<span class="sd">                * PositionEmbeddingType.rope_gpt_neox</span>
<span class="sd">                * PositionEmbeddingType.alibi</span>
<span class="sd">                * PositionEmbeddingType.alibi_with_scale</span>

<span class="sd">        kv_orig_quant_scale: Tensor</span>
<span class="sd">            The tensor to store the scaling factor for quantization to INT8/FP8</span>
<span class="sd">            in the KV cache. Its shape is [1]. See INT8/FP8 KV Cache in</span>
<span class="sd">            docs/source/advanced/gpt-attention.md,</span>

<span class="sd">        kv_quant_orig_scale: Tensor</span>
<span class="sd">            The tensor to store the scaling factor for dequantization from</span>
<span class="sd">            INT8/FP8 in the KV cache. Its shape is [1]. See INT8/FP8 KV Cache</span>
<span class="sd">            in docs/source/advanced/gpt-attention.md,</span>

<span class="sd">        attention_output_orig_quant_scale: Tensor</span>
<span class="sd">            The tensor to store the scaling factor for quantization of the</span>
<span class="sd">            attention output to FP8. Its shape is [1].</span>

<span class="sd">        kv_cache_quant_mode: QuantMode (int flags)</span>
<span class="sd">            Do we enable the INT8 or FP8 KV cache?</span>

<span class="sd">        max_context_length: int32_t</span>
<span class="sd">            The length of the longest input sequence. See QKV Input in</span>
<span class="sd">            docs/source/advanced/gpt-attention.md,</span>

<span class="sd">        mask_type: AttentionMaskType = AttentionMaskType.causal</span>
<span class="sd">            The type of mask:</span>
<span class="sd">                * tensorrt_llm.layers.AttentionMaskType.padding for BERT,</span>
<span class="sd">                * tensorrt_llm.layers.AttentionMaskType.causal for GPT,</span>
<span class="sd">                * tensorrt_llm.layers.AttentionMaskType.sliding_window_causal for GPT,</span>
<span class="sd">                * tensorrt_llm.layers.AttentionMaskType.bidirectional for ChatGLM-6B,</span>
<span class="sd">                * tensorrt_llm.layers.AttentionMaskType.bidirectionalglm for GLM-10B,</span>
<span class="sd">                * tensorrt_llm.layers.AttentionMaskType.blocksparse for Phi-3-small,</span>
<span class="sd">                * tensorrt_llm.layers.AttentionMaskType.custom_mask for any models.</span>

<span class="sd">        block_sparse_block_size: int</span>
<span class="sd">            Block size in block sparse attention,</span>

<span class="sd">        block_sparse_homo_head_pattern: bool</span>
<span class="sd">            Do all attention heads share the same vertical stride pattern?</span>

<span class="sd">        block_sparse_num_local_blocks: int</span>
<span class="sd">            Number of active blocks near the diagonal,</span>

<span class="sd">        block_sparse_vertical_stride: int</span>
<span class="sd">            Stride of active blocks in the vertical dimension,</span>

<span class="sd">        alibi_slopes: Tensor</span>
<span class="sd">            The ALiBi slopes. The ALiBi bias is computed on-the-fly in the kernel</span>
<span class="sd">            when possible,</span>

<span class="sd">        tp_size: int</span>
<span class="sd">            The number of processes/GPUs when tensor parallelism is activated,</span>

<span class="sd">        tp_rank: int</span>
<span class="sd">            The rank of that process (when running tensor parallelism),</span>

<span class="sd">        kv_cache_block_offsets:</span>
<span class="sd">            The tensor of block offsets for the KV cache. Its shape is</span>
<span class="sd">            [num_layers, max_batch_size, max_beam_width, 2, max_blocks_per_sequence * 2],</span>
<span class="sd">            See KV cache section in docs/source/advanced/gpt-attention.md, on GPU,</span>

<span class="sd">        host_kv_cache_block_offsets:</span>
<span class="sd">            The same as kv_cache_block_offsets, but on CPU,</span>

<span class="sd">        host_kv_cache_pool_pointers:</span>
<span class="sd">            The tensor of pool pointers for the KV cache. Its shape is [num_layers, 2],</span>
<span class="sd">            See KV cache section in docs/source/advanced/gpt-attention.md, on CPU,</span>

<span class="sd">        host_kv_cache_pool_mapping:</span>
<span class="sd">            The tensor of pool mapping for the different memory pools. Its shape is [num_layers, 2] - for each layer, the index of the pool, and the index of the layer within the pool,</span>

<span class="sd">        do_cross_attention: bool = False</span>
<span class="sd">            Do we use this as cross attention instead of self attention?</span>

<span class="sd">        cross_kv: Tensor = None</span>
<span class="sd">            The KV tensor of encoder output hidden states. Its shape is [batch_size, max_seqlen, 2 * kvHeadNum * headSize] in padded mode and [1, num_tokens, 2 * kvHeadNum * headSize] in</span>
<span class="sd">            packed mode,</span>

<span class="sd">        cross_kv_length: Tensor = None</span>
<span class="sd">            The length of the longest encoder output sequence,</span>

<span class="sd">        encoder_input_lengths: Tensor</span>
<span class="sd">            The tensor that stores the length of each encoder input sequence. Its shape is [batch_size],</span>

<span class="sd">        logn_scaling: Tensor = None</span>
<span class="sd">            The logn scaling tensor [max_position_embedding_len], which is applied to q in order to help extrapolation,</span>

<span class="sd">        relative_attention_bias: Tensor = None</span>
<span class="sd">            The relative attention bias [num_heads, max_seq_len, max_seq_len], or the relative attention embedding table for implicit mode, [num_heads, num_buckets].</span>

<span class="sd">        max_distance: int = 0</span>
<span class="sd">            The maximum distance of relative position in attention, for implicit mode.</span>
<span class="sd">            The default value is 0, meaning to use the regular mode of relative attention bias.</span>
<span class="sd">            Implicit mode is only enabled when passing in a non-zero positive max_distance value.</span>
<span class="sd">            See relative attention bias in docs/source/advanced/gpt-attention.md,</span>

<span class="sd">        host_context_lengths: Tensor = None (On CPU)</span>
<span class="sd">            A host tensor that contains the lengths of the different inputs,</span>

<span class="sd">        qkv_bias: Tensor = None</span>
<span class="sd">            The qkv bias tensor.</span>

<span class="sd">        use_cache: bool = True</span>
<span class="sd">            Do we need to store the KV cache? Not needed if there is no generation phase.</span>

<span class="sd">        spec_decoding_is_generation_length_variable: bool = False</span>
<span class="sd">            Whether the generation lengths can be different for each sequence in a batch.</span>
<span class="sd">            For Medusa, this should be set to False.</span>
<span class="sd">            For ReDrafter, this should be set to True.</span>

<span class="sd">        spec_decoding_max_generation_length: int = 0</span>
<span class="sd">            The maximum number of tokens possible in the generation phase per sequence.</span>

<span class="sd">        spec_decoding_generation_lengths: Tensor = None</span>
<span class="sd">            The generation-phase token lengths for each sequence.</span>
<span class="sd">            Shape: [batch_size]</span>

<span class="sd">        spec_decoding_position_offsets: Tensor = None</span>
<span class="sd">            The speculative decoding tokens' position offsets (shared by all sequences).</span>
<span class="sd">            Shape: [batch_size, num_draft_tokens + 1].</span>

<span class="sd">        spec_decoding_packed_mask: Tensor = None</span>
<span class="sd">            The speculative decoding tokens' attention mask (packed into uint32_t bits).</span>
<span class="sd">            remove_input_padding is False:</span>
<span class="sd">                Shape: [batch_size, num_draft_tokens + 1, divUp(num_draft_tokens + 1, 32)].</span>
<span class="sd">            remove_input_padding is True:</span>
<span class="sd">                Shape: [sum(spec_decoding_generation_lengths), divUp(num_draft_tokens + 1, 32)].</span>

<span class="sd">        long_rope_rotary_inv_freq: float Tensor</span>
<span class="sd">            Additional rotary inv freq used for longer sequence lengths. Shape: [head_size / 2]</span>

<span class="sd">        long_rope_rotary_cos_sin: float2(cos/sin) Tensor</span>
<span class="sd">            Additional rotary cos/sin cache used for longer sequence lengths.</span>

<span class="sd">        is_mla_enabled_flag: bool = False</span>
<span class="sd">            Do we need to enable DeepSeek-V2 MLA?</span>

<span class="sd">        host_runtime_perf_knobs: Tensor = None</span>
<span class="sd">            The runtime perf knobs bit mask; controls whether certain perf knobs are used in the runtime.</span>

<span class="sd">        host_context_progress: Tensor = None</span>
<span class="sd">            The structure used to track layer-wise progress in the context phase.</span>

<span class="sd">        skip_attn: Tensor = None</span>
<span class="sd">            A bool tensor on CPU. If it is true, the attention plugin is skipped and returns directly.</span>

<span class="sd">        num_kv_heads_origin: int</span>
<span class="sd">            The original number of KV heads, before tensor-parallel splitting,</span>

<span class="sd">    Returns:</span>
<span class="sd">        The tensor produced by that layer.</span>
<span class="sd">    '''</span>
    <span class="k">assert</span> <span class="n">host_request_types</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
    <span class="k">assert</span> <span class="p">(</span><span class="n">alibi_slopes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="o">==</span> <span class="p">(</span><span class="n">position_embedding_type</span><span class="o">.</span><span class="n">is_alibi</span><span class="p">())</span>
    <span class="k">assert</span> <span class="p">(</span><span class="n">mrope_rotary_cos_sin</span>
            <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="o">==</span> <span class="p">(</span><span class="n">position_embedding_type</span><span class="o">.</span><span class="n">is_mrope</span><span class="p">())</span>
    <span class="n">attn_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
        <span class="s1">'GPTAttention'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
    <span class="k">assert</span> <span class="n">attn_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
    <span class="k">assert</span> <span class="n">host_context_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">(</span>
    <span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span>
    <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span>
    <span class="k">assert</span> <span class="n">host_max_attention_window_sizes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
    <span class="k">assert</span> <span class="n">host_sink_token_length</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>

    <span class="n">paged_kv_cache_flag</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_kv_cache</span>
    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
        <span class="n">is_unfuse_qkv_gemm</span> <span class="o">=</span> <span class="mi">1</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">is_unfuse_qkv_gemm</span> <span class="o">=</span> <span class="mi">0</span>

    <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">context_fmha_type</span>
    <span class="k">if</span> <span class="n">do_cross_attention</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
        <span class="k">pass</span>
    <span class="k">if</span> <span class="n">logn_scaling</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">use_logn_scaling</span> <span class="o">=</span> <span class="mi">1</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">use_logn_scaling</span> <span class="o">=</span> <span class="mi">0</span>

    <span class="k">if</span> <span class="n">num_kv_heads_origin</span> <span class="o"><</span> <span class="mi">1</span><span class="p">:</span>
        <span class="n">num_kv_heads_origin</span> <span class="o">=</span> <span class="n">num_kv_heads</span>

    <span class="n">unfuse_qkv_gemm</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
        <span class="s2">"unfuse_qkv_gemm"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">is_unfuse_qkv_gemm</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
        <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>

    <span class="n">layer_idx</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"layer_idx"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">layer_idx</span><span class="p">,</span>
                                                      <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                                <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">nheads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"num_heads"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                             <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">vision_start</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"vision_start"</span><span class="p">,</span>
                                    <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">vision_start</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                                    <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">vision_length</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"vision_length"</span><span class="p">,</span>
                                     <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">vision_length</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                                     <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">num_kv_heads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"num_kv_heads"</span><span class="p">,</span>
                                    <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_kv_heads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                                    <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">num_kv_heads_origin</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
        <span class="s2">"num_kv_heads_origin"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_kv_heads_origin</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
        <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">head_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"head_size"</span><span class="p">,</span>
                                 <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">hidden_size_per_head</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                                 <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">unidirectional</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"unidirectional"</span><span class="p">,</span>
                                      <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
                                      <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">q_scaling</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"q_scaling"</span><span class="p">,</span>
                                 <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">q_scaling</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
                                 <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
    <span class="n">attn_logit_softcapping_scale</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
        <span class="s2">"attn_logit_softcapping_scale"</span><span class="p">,</span>
        <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">attn_logit_softcapping_scale</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
        <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
    <span class="n">rotary_embedding_dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
        <span class="s2">"rotary_embedding_dim"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
        <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
    <span class="n">rotary_embedding_base</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
        <span class="s2">"rotary_embedding_base"</span><span class="p">,</span>
        <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_base</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
        <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
    <span class="n">rotary_embedding_scale_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
        <span class="s2">"rotary_embedding_scale_type"</span><span class="p">,</span>
        <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_scale_type</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
        <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
    <span class="n">rotary_embedding_scale</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
        <span class="s2">"rotary_embedding_scale"</span><span class="p">,</span>
        <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_scale</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
        <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
    <span class="n">rotary_embedding_short_m_scale</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
        <span class="s2">"rotary_embedding_short_m_scale"</span><span class="p">,</span>
        <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_short_m_scale</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
        <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
    <span class="n">rotary_embedding_long_m_scale</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
        <span class="s2">"rotary_embedding_long_m_scale"</span><span class="p">,</span>
        <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_long_m_scale</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
|
||
<span class="n">rotary_embedding_max_positions</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"rotary_embedding_max_positions"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_max_positions</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">rotary_embedding_original_max_positions</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"rotary_embedding_original_max_positions"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_original_max_positions</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">position_embedding_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"position_embedding_type"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">position_embedding_type</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">context_fmha_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"context_fmha_type"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">context_fmha_type</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"remove_input_padding"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">is_spec_decoding_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"is_spec_decoding_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">spec_decoding_packed_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">spec_decoding_is_generation_length_variable</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"spec_decoding_is_generation_length_variable"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">spec_decoding_is_generation_length_variable</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">spec_decoding_max_generation_length</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"spec_decoding_max_generation_length"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">spec_decoding_max_generation_length</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">is_mla_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"is_mla_enabled"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">is_mla_enabled_flag</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">q_lora_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"q_lora_rank"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">q_lora_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">kv_lora_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"kv_lora_rank"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">kv_lora_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">qk_nope_head_dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"qk_nope_head_dim"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">qk_nope_head_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">qk_rope_head_dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"qk_rope_head_dim"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">qk_rope_head_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">v_head_dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"v_head_dim"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">v_head_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
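    # Note (editor): every compile-time knob above is serialized into a
    # trt.PluginField whose NumPy dtype matches the declared
    # trt.PluginFieldType; the GPTAttention plugin deserializes the fields
    # by name on the C++ side. A minimal sketch of the same pattern for a
    # hypothetical field (name and value are illustrative only):
    #
    #   example_pf = trt.PluginField("example_flag",
    #                                np.array(1, dtype=np.int32),
    #                                trt.PluginFieldType.INT32)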
<span class="c1"># reset mask_type to custom_mask.</span>
|
||
<span class="k">if</span> <span class="p">(</span><span class="n">attention_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">attention_packed_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">):</span>
|
||
<span class="c1"># context fmha needs packed mask.</span>
|
||
<span class="k">assert</span> <span class="n">attention_packed_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">if</span> <span class="n">get_sm_version</span><span class="p">()</span> <span class="o"><</span> <span class="mi">100</span><span class="p">:</span>
|
||
<span class="n">mask_type</span> <span class="o">=</span> <span class="n">AttentionMaskType</span><span class="o">.</span><span class="n">custom_mask</span>
|
||
|
||
<span class="n">mask_type_filed</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"mask_type"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">mask_type</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">block_sparse_block_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"block_sparse_block_size"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">block_sparse_block_size</span><span class="p">],</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">block_sparse_homo_head_pattern</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"block_sparse_homo_head_pattern"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">block_sparse_homo_head_pattern</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">block_sparse_num_local_blocks</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"block_sparse_num_local_blocks"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">block_sparse_num_local_blocks</span><span class="p">],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">block_sparse_vertical_stride</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"block_sparse_vertical_stride"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">block_sparse_vertical_stride</span><span class="p">],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">tp_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"tp_size"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tp_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">tp_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"tp_rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tp_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">kv_cache_quant_mode</span><span class="p">,</span> <span class="n">QuantModeWrapper</span><span class="p">):</span>
|
||
<span class="c1"># Now in TRT-LLM only use global kv_cache, so it's enough to get the first quant mode from list</span>
|
||
<span class="n">kv_cache_quant_mode</span> <span class="o">=</span> <span class="n">kv_cache_quant_mode</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
|
||
<span class="n">kv_cache_quant_mode_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"kv_cache_quant_mode"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">kv_cache_quant_mode</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
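    # Note (editor): QuantMode behaves as an integer bitmask here, so the
    # np.array(...) call above serializes its raw flag value as INT32; the
    # plugin is assumed to re-interpret those bits on the C++ side.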
<span class="n">paged_kv_cache</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"paged_kv_cache"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">paged_kv_cache_flag</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">tokens_per_block</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"tokens_per_block"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"max_context_length"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pos_shift_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"pos_shift_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">streamingllm</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">dense_context_fmha</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"dense_context_fmha"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">streamingllm</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
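    # Note (editor): pos_shift_enabled and dense_context_fmha are both driven
    # by the same plugin_config.streamingllm switch, so enabling StreamingLLM
    # turns on position shifting and dense context FMHA together.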
<span class="k">if</span> <span class="n">qkv_bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">qkv_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"qkv_bias_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">qkv_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"qkv_bias_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">do_cross_attention_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"do_cross_attention"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">do_cross_attention</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">max_distance</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"max_distance"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_distance</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">use_paged_context_fmha_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"use_paged_context_fmha"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">use_fp8_context_fmha_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"use_fp8_context_fmha"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">has_full_attention_mask_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"has_full_attention_mask"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">attention_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">use_cache_pf</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"use_cache"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">use_cache</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">fuse_fp4_quant</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">fuse_fp4_quant</span>
|
||
<span class="n">fuse_fp4_quant_pf</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"fuse_fp4_quant"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">fuse_fp4_quant</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">skip_attn_pf</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"skip_attn"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">skip_attn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">cp_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"cp_size"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">cp_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">cp_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"cp_rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">cp_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">cp_group</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">cp_group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
|
||
<span class="n">cp_group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"cp_group"</span><span class="p">,</span> <span class="n">cp_group</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">use_logn_scaling</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"use_logn_scaling"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">use_logn_scaling</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
|
||
<span class="n">layer_idx</span><span class="p">,</span> <span class="n">nheads</span><span class="p">,</span> <span class="n">vision_start</span><span class="p">,</span> <span class="n">vision_length</span><span class="p">,</span> <span class="n">num_kv_heads</span><span class="p">,</span>
|
||
<span class="n">num_kv_heads_origin</span><span class="p">,</span> <span class="n">head_size</span><span class="p">,</span> <span class="n">unidirectional</span><span class="p">,</span> <span class="n">q_scaling</span><span class="p">,</span>
|
||
<span class="n">attn_logit_softcapping_scale</span><span class="p">,</span> <span class="n">position_embedding_type</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_dim</span><span class="p">,</span> <span class="n">rotary_embedding_base</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_scale_type</span><span class="p">,</span> <span class="n">rotary_embedding_scale</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_short_m_scale</span><span class="p">,</span> <span class="n">rotary_embedding_long_m_scale</span><span class="p">,</span>
|
||
<span class="n">rotary_embedding_max_positions</span><span class="p">,</span> <span class="n">rotary_embedding_original_max_positions</span><span class="p">,</span>
|
||
<span class="n">tp_size</span><span class="p">,</span> <span class="n">tp_rank</span><span class="p">,</span> <span class="n">unfuse_qkv_gemm</span><span class="p">,</span> <span class="n">context_fmha_type</span><span class="p">,</span>
|
||
<span class="n">kv_cache_quant_mode_field</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span> <span class="n">mask_type_filed</span><span class="p">,</span>
|
||
<span class="n">block_sparse_block_size</span><span class="p">,</span> <span class="n">block_sparse_homo_head_pattern</span><span class="p">,</span>
|
||
<span class="n">block_sparse_num_local_blocks</span><span class="p">,</span> <span class="n">block_sparse_vertical_stride</span><span class="p">,</span>
|
||
<span class="n">paged_kv_cache</span><span class="p">,</span> <span class="n">tokens_per_block</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
|
||
<span class="n">qkv_bias_enabled</span><span class="p">,</span> <span class="n">do_cross_attention_field</span><span class="p">,</span> <span class="n">max_distance</span><span class="p">,</span>
|
||
<span class="n">pos_shift_enabled</span><span class="p">,</span> <span class="n">dense_context_fmha</span><span class="p">,</span> <span class="n">use_paged_context_fmha_field</span><span class="p">,</span>
|
||
<span class="n">use_fp8_context_fmha_field</span><span class="p">,</span> <span class="n">has_full_attention_mask_field</span><span class="p">,</span> <span class="n">use_cache_pf</span><span class="p">,</span>
|
||
<span class="n">is_spec_decoding_enabled</span><span class="p">,</span> <span class="n">spec_decoding_is_generation_length_variable</span><span class="p">,</span>
|
||
<span class="n">spec_decoding_max_generation_length</span><span class="p">,</span> <span class="n">is_mla_enabled</span><span class="p">,</span> <span class="n">q_lora_rank</span><span class="p">,</span>
|
||
<span class="n">kv_lora_rank</span><span class="p">,</span> <span class="n">qk_nope_head_dim</span><span class="p">,</span> <span class="n">qk_rope_head_dim</span><span class="p">,</span> <span class="n">v_head_dim</span><span class="p">,</span>
|
||
<span class="n">fuse_fp4_quant_pf</span><span class="p">,</span> <span class="n">skip_attn_pf</span><span class="p">,</span> <span class="n">cp_size</span><span class="p">,</span> <span class="n">cp_rank</span><span class="p">,</span> <span class="n">cp_group</span><span class="p">,</span>
|
||
<span class="n">use_logn_scaling</span>
|
||
<span class="p">])</span>
|
||
|
||
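    # Note (editor): the ordering inside this PluginFieldCollection is for
    # readability only; the plugin creator is assumed to look fields up by
    # name rather than by position when it deserializes them.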
<span class="n">attn_plug</span> <span class="o">=</span> <span class="n">attn_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"causal_attn"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">attn_plug</span>
|
||
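    # Note (editor): attn_plg_creator is obtained earlier in this function
    # (outside this excerpt). The usual TensorRT pattern is a registry lookup,
    # sketched here with assumed name/version/namespace arguments:
    #
    #   attn_plg_creator = trt.get_plugin_registry().get_plugin_creator(
    #       "GPTAttention", "1", TRT_LLM_PLUGIN_NAMESPACE)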
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="o">*</span><span class="n">qkv</span><span class="p">]</span> <span class="k">if</span> <span class="n">is_unfuse_qkv_gemm</span> <span class="k">else</span> <span class="p">[</span><span class="n">qkv</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">attention_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">mask_type</span> <span class="o">==</span> <span class="n">AttentionMaskType</span><span class="o">.</span><span class="n">custom_mask</span><span class="p">:</span>
|
||
<span class="c1"># useFullCustomMask</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">attention_mask</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">attention_packed_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">get_sm_version</span><span class="p">()</span> <span class="o"><</span> <span class="mi">100</span><span class="p">:</span>
|
||
<span class="c1"># usePackedCustomMask</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">attention_packed_mask</span><span class="p">]</span>
|
||
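    # Note (editor): from here on, the order in which tensors are appended to
    # plug_inputs matters; it must match the positional input order the
    # GPTAttention plugin expects, which is why every optional tensor below is
    # gated by the same flag that was serialized into the plugin fields above.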
<span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span>
|
||
<span class="n">sequence_length</span><span class="p">,</span>
|
||
<span class="n">host_past_key_value_lengths</span><span class="p">,</span>
|
||
<span class="n">host_max_attention_window_sizes</span><span class="p">,</span>
|
||
<span class="n">host_sink_token_length</span><span class="p">,</span>
|
||
<span class="n">context_lengths</span><span class="p">,</span>
|
||
<span class="n">cache_indirection</span><span class="p">,</span>
|
||
<span class="n">host_request_types</span><span class="p">,</span>
|
||
<span class="p">]</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span>
|
||
<span class="n">host_max_attention_window_sizes</span><span class="p">,</span>
|
||
<span class="n">host_sink_token_length</span><span class="p">,</span>
|
||
<span class="n">context_lengths</span><span class="p">,</span>
|
||
<span class="n">host_request_types</span><span class="p">,</span>
|
||
<span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">kv_cache_block_offsets</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"Paged kv cache is enabled, the kv_cache_block_offsets tensor shall not be None"</span>
|
||
<span class="k">assert</span> <span class="n">host_kv_cache_block_offsets</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"Paged kv cache is enabled, the host_kv_cache_block_offsets tensor shall not be None"</span>
|
||
<span class="k">assert</span> <span class="n">host_kv_cache_pool_pointers</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"Paged kv cache is enabled, the host_kv_cache_pool_pointers tensor shall not be None"</span>
|
||
<span class="k">assert</span> <span class="n">host_kv_cache_pool_mapping</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"Paged kv cache is enabled, the host_kv_cache_pool_mapping tensor shall not be None"</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span>
|
||
<span class="n">kv_cache_block_offsets</span><span class="p">,</span> <span class="n">host_kv_cache_block_offsets</span><span class="p">,</span>
|
||
<span class="n">host_kv_cache_pool_pointers</span><span class="p">,</span> <span class="n">host_kv_cache_pool_mapping</span>
|
||
<span class="p">]</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">past_key_value</span><span class="p">]</span>
|
||
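    # Note (editor): with the paged kv cache the plugin receives block-offset
    # and pool-pointer tensors instead of one contiguous past_key_value
    # tensor; the two layouts are mutually exclusive for a given engine build
    # since paged_kv_cache_flag is fixed at build time.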

    if use_cache and kv_cache_quant_mode.has_kv_cache_quant():
        plug_inputs += [kv_orig_quant_scale, kv_quant_orig_scale]

    if attention_output_orig_quant_scale is not None:
        assert default_net().plugin_config.use_fp8_context_fmha, "FP8 Context FMHA needs to be enabled"
        plug_inputs += [attention_output_orig_quant_scale]

    if fuse_fp4_quant:
        assert attention_output_sf_scale is not None, "attention_output_sf_scale must be provided when fuse_fp4_quant is enabled."
        plug_inputs += [attention_output_sf_scale]

<span class="k">if</span> <span class="n">rotary_inv_freq</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">rotary_inv_freq</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">rotary_cos_sin</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">rotary_cos_sin</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">alibi_slopes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">alibi_slopes</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">relative_attention_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">relative_attention_bias</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">do_cross_attention</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">cross_kv</span><span class="p">,</span> <span class="n">cross_kv_length</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_lengths</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">qkv_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">qkv_bias</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">spec_decoding_packed_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="c1"># add position_ids as well only if speculative decoding mode</span>
|
||
<span class="k">assert</span> <span class="n">spec_decoding_position_offsets</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="n">spec_decoding_generation_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="n">spec_decoding_use</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span>
|
||
<span class="n">spec_decoding_generation_lengths</span><span class="p">,</span> <span class="n">spec_decoding_packed_mask</span><span class="p">,</span>
|
||
<span class="n">spec_decoding_position_offsets</span><span class="p">,</span> <span class="n">spec_decoding_use</span>
|
||
<span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">long_rope_rotary_inv_freq</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">long_rope_rotary_cos_sin</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">long_rope_rotary_inv_freq</span><span class="p">,</span> <span class="n">long_rope_rotary_cos_sin</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">mrope_rotary_cos_sin</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">mrope_position_deltas</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span>
|
||
<span class="n">mrope_rotary_cos_sin</span><span class="p">,</span>
|
||
<span class="n">mrope_position_deltas</span><span class="p">,</span>
|
||
<span class="p">]</span>
|
||
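    # Note (editor): the optional groups above (speculative decoding,
    # LongRoPE, and mRoPE tensors) come in fixed bundles: if the first tensor
    # of a bundle is provided, its companions are asserted to be present so
    # the positional plugin inputs stay consistent.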
<span class="k">if</span> <span class="n">host_runtime_perf_knobs</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_runtime_perf_knobs</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">host_context_progress</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_progress</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="n">is_mla_enabled_flag</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">q_b_proj</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">assert</span> <span class="n">kv_b_proj</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="n">k_b_proj_trans</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">q_b_proj</span><span class="p">,</span> <span class="n">kv_b_proj</span><span class="p">,</span> <span class="n">k_b_proj_trans</span><span class="p">]</span>
<span class="k">if</span> <span class="n">skip_attn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">skip_attn</span><span class="p">]</span>
<span class="k">if</span> <span class="n">logn_scaling</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">logn_scaling</span><span class="p">]</span>
<span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">i</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"Found None input for </span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s2"> th item in plugin inputs </span><span class="si">{</span><span class="n">plug_inputs</span><span class="si">}</span><span class="s2">"</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">attn_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">attn_plg_creator</span><span class="p">,</span> <span class="s2">"causal_attn"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="n">expected_outputs</span> <span class="o">=</span> <span class="mi">1</span>
<span class="c1"># The output scaling factor tensor.</span>
<span class="n">output_sf</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">fuse_fp4_quant</span><span class="p">:</span>
<span class="n">output_sf</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="n">expected_outputs</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="n">expected_outputs</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="n">present_key_value</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">use_cache</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
<span class="n">present_key_value</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="n">expected_outputs</span><span class="p">),</span>
<span class="n">layer</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">present_key_value</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">expected_outputs</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">assert</span> <span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span> <span class="o">==</span> <span class="n">expected_outputs</span><span class="p">,</span> \
<span class="sa">f</span><span class="s2">"Plugin outputs number mismatch with expected, got </span><span class="si">{</span><span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span><span class="si">}</span><span class="s2">, expected </span><span class="si">{</span><span class="n">expected_outputs</span><span class="si">}</span><span class="s2">"</span>
<span class="k">if</span> <span class="n">kv_cache_quant_mode</span><span class="o">.</span><span class="n">has_int8_kv_cache</span><span class="p">(</span>
<span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
<span class="c1"># past key value</span>
<span class="n">layer</span><span class="o">.</span><span class="n">get_input</span><span class="p">(</span><span class="mi">8</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
<span class="c1"># present key value</span>
<span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="n">expected_outputs</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">get_input</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">get_input</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="n">expected_outputs</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">fuse_fp4_quant</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">output_sf</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">return</span> <span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">output_sf</span><span class="p">),</span> <span class="n">present_key_value</span>
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">present_key_value</span></div>



<div class="viewcode-block" id="assertion">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.assertion">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">assertion</span><span class="p">(</span><span class="n">condition</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">message</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">''</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
<span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_assertion</span><span class="p">(</span><span class="n">condition</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">message</span><span class="p">)</span></div>



<div class="viewcode-block" id="layer_norm">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.layer_norm">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">layer_norm</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">normalized_shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">,</span>
<span class="n">use_diff_of_squares</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">'''</span>
<span class="sd"> Add a layer-norm operation on a tensor.</span>
<span class="sd"> That operation applies the layer-normalization to its input tensor. In its</span>
<span class="sd"> simplest form, for large language models, the 'normalized_shape' should be</span>
<span class="sd"> set to the hidden dimension of the activation tensor. Otherwise, it is the</span>
<span class="sd"> shape of the normalized fraction of the tensor (starting from the</span>
<span class="sd"> right-most dimension).</span>
<span class="sd"> The 'weight' tensor corresponds to 'gamma' in the layer-norm formula and</span>
<span class="sd"> 'bias' is 'beta'. The 'eps' value is added to the variance before computing</span>
<span class="sd"> the squared-root.</span>
<span class="sd"> This implementation (when using the plugin) supports an additional flag to</span>
<span class="sd"> enable/disable the use of a difference of squares ('Var = Mean(X^2) -</span>
<span class="sd"> Mean(X)^2').</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The tensor to normalize.</span>
<span class="sd"> normalized_shape : Union[int, Tuple[int]]</span>
<span class="sd"> The shape of the sub-tensor that is normalized. Use 'hidden_dim' to</span>
<span class="sd"> normalize the inner-most dimension of an activation tensor in LLMs.</span>
<span class="sd"> weight : Optional[Tensor] = None</span>
<span class="sd"> The 'gamma' term in layer-norm. Its shape must be</span>
<span class="sd"> 'normalized_shape'.</span>
<span class="sd"> bias : Optional[Tensor] = None</span>
<span class="sd"> The 'beta' term in layer-norm. Its shape must be</span>
<span class="sd"> 'normalized_shape'.</span>
<span class="sd"> eps : float</span>
<span class="sd"> The epsilon term to be added to the variance in the squared-root.</span>
<span class="sd"> use_diff_of_squares : bool</span>
<span class="sd"> Does the plugin use the difference of squares to compute the</span>
<span class="sd"> variance?</span>
<span class="sd"> Returns:</span>
<span class="sd"> The output tensor of that operation.</span>
<span class="sd"> '''</span>
<span class="nb">input</span><span class="p">,</span> <span class="n">weight</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">weight</span><span class="p">)</span>
<span class="nb">input</span><span class="p">,</span> <span class="n">bias</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span> <span class="c1"># FIXME: better way?</span>
<span class="n">axis</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">axis</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">)</span>
<span class="n">axes_mask</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()):</span>
<span class="n">axes_mask</span> <span class="o">|=</span> <span class="mi">1</span> <span class="o"><<</span> <span class="n">i</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_normalization</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">axes_mask</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">epsilon</span> <span class="o">=</span> <span class="n">eps</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>



<div class="viewcode-block" id="rms_norm">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.rms_norm">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">rms_norm</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">normalized_shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">num_groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-06</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">'''</span>
<span class="sd"> Add a RMS norm operation on a tensor.</span>
<span class="sd"> That operation applies the rms-normalization to its input tensor. In its</span>
<span class="sd"> simplest form, for large language models, the 'normalized_shape' should be</span>
<span class="sd"> set to the hidden dimension of the activation tensor. Otherwise, it is the</span>
<span class="sd"> shape of the normalized fraction of the tensor (starting from the</span>
<span class="sd"> right-most dimension).</span>
<span class="sd"> The 'weight' tensor corresponds to 'gamma' in the rms-norm formula.</span>
<span class="sd"> The 'eps' value is added to the variance before computing the squared-root.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input: Tensor</span>
<span class="sd"> The tensor to normalize.</span>
<span class="sd"> normalized_shape : Union[int, Tuple[int]]</span>
<span class="sd"> The shape of the sub-tensor that is normalized. Use 'hidden_dim' to</span>
<span class="sd"> normalize the inner-most dimension of an activation tensor in LLMs.</span>
<span class="sd"> num_groups: int = 1</span>
<span class="sd"> The group size.</span>
<span class="sd"> weight : Optional[Tensor] = None</span>
<span class="sd"> The 'gamma' term in layer-norm. Its shape must be</span>
<span class="sd"> 'normalized_shape'.</span>
<span class="sd"> eps : float</span>
<span class="sd"> The epsilon term to be added to the variance in the squared-root.weig</span>
|
||
<span class="sd"> Returns:</span>
<span class="sd"> The output tensor of that operation.</span>
<span class="sd"> '''</span>
<span class="n">normalized_shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">normalized_shape</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
<span class="n">normalized_shape</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> <span class="n">normalized_shape</span>
<span class="n">dim</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="o">-</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">))])</span>
<span class="k">if</span> <span class="n">num_groups</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span>
<span class="n">num_channels</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="n">old_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">+</span>
<span class="p">[</span><span class="n">num_groups</span><span class="p">,</span> <span class="n">num_channels</span> <span class="o">//</span> <span class="n">num_groups</span><span class="p">])</span>
<span class="nb">input</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
<span class="k">with</span> <span class="n">precision</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">):</span>
<span class="n">input_dtype</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span>
<span class="n">fp32_input</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="s2">"float32"</span><span class="p">)</span>
<span class="n">varx</span> <span class="o">=</span> <span class="nb">pow</span><span class="p">(</span><span class="n">fp32_input</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">)</span>
<span class="n">varx</span> <span class="o">=</span> <span class="n">varx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">denom</span> <span class="o">=</span> <span class="n">varx</span> <span class="o">+</span> <span class="n">eps</span>
<span class="n">denom</span> <span class="o">=</span> <span class="n">denom</span><span class="o">.</span><span class="n">sqrt</span><span class="p">()</span>
<span class="n">fp32_y</span> <span class="o">=</span> <span class="n">fp32_input</span> <span class="o">/</span> <span class="n">denom</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">fp32_y</span><span class="p">,</span> <span class="n">input_dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">num_groups</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">old_shape</span><span class="p">)</span>
<span class="k">if</span> <span class="n">weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">*</span> <span class="n">weight</span>
<span class="k">return</span> <span class="n">y</span></div>



<div class="viewcode-block" id="rearrange">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.rearrange">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">rearrange</span><span class="p">(</span><span class="n">inputs</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]],</span> <span class="n">expression</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">'''</span>
<span class="sd"> Add a rearrange operation on a tensor.</span>
<span class="sd"> This operation is a reader-friendly smart element reordering for multidimensional tensors,</span>
<span class="sd"> including functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,</span>
<span class="sd"> stack, concatenate and other operations. Please see: https://einops.rocks/api/rearrange/</span>
<span class="sd"> For example, if the shape of input tensor is [32, 30, 40, 3], and run:</span>
<span class="sd"> `rearrange(x, 'b (h h1) (w w1) c -> b h w 1 (c h1 w1) 1', h1=2, w1=2)`</span>
<span class="sd"> it would produce a tensor with shape as [32, 15, 20, 1, 12, 1].</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input: Union[Tensor, Sequence[Tensor]]</span>
<span class="sd"> If it is a tensor, it will directly operate on it.</span>
<span class="sd"> Otherwise, if it is a sequence, it will concat it to a tensor and then</span>
|
||
<span class="sd"> operates on it.</span>
<span class="sd"> expression : str</span>
<span class="sd"> The expression about how to reorder the tensor in a reader-friendly way.</span>
<span class="sd"> kwargs:</span>
<span class="sd"> Keyword arguments to set some identifiers with specific values.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The output tensor of this operation.</span>
<span class="sd"> '''</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">re</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_init_expression</span><span class="p">(</span><span class="n">expr</span><span class="p">):</span>
<span class="n">expr_items</span> <span class="o">=</span> <span class="n">expr</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">" "</span><span class="p">)</span>
<span class="n">tmp_name_index</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">item</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">expr_items</span><span class="p">):</span>
<span class="n">values</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">findall</span><span class="p">(</span><span class="sa">r</span><span class="s1">'\b\d+\b'</span><span class="p">,</span> <span class="n">item</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">values</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
<span class="n">prefix</span> <span class="o">=</span> <span class="s2">"("</span> <span class="k">if</span> <span class="s2">"("</span> <span class="ow">in</span> <span class="n">item</span> <span class="k">else</span> <span class="s2">""</span>
<span class="n">subfix</span> <span class="o">=</span> <span class="s2">")"</span> <span class="k">if</span> <span class="s2">")"</span> <span class="ow">in</span> <span class="n">item</span> <span class="k">else</span> <span class="s2">""</span>
<span class="n">expr_items</span><span class="p">[</span>
<span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">prefix</span><span class="si">}</span><span class="s2">NumericId</span><span class="si">{</span><span class="n">tmp_name_index</span><span class="si">}</span><span class="s2">Val</span><span class="si">{</span><span class="n">values</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}{</span><span class="n">subfix</span><span class="si">}</span><span class="s2">"</span>
<span class="n">tmp_name_index</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">return</span> <span class="s2">" "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">expr_items</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_get_all_identifier</span><span class="p">(</span><span class="n">expr</span><span class="p">):</span>
<span class="k">return</span> <span class="n">re</span><span class="o">.</span><span class="n">findall</span><span class="p">(</span><span class="sa">r</span><span class="s1">'\b[a-zA-Z_]+\d*\b'</span><span class="p">,</span> <span class="n">expr</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_get_all_symbols</span><span class="p">(</span><span class="n">expr</span><span class="p">):</span>
<span class="k">return</span> <span class="n">re</span><span class="o">.</span><span class="n">findall</span><span class="p">(</span><span class="sa">r</span><span class="s1">'\b\w+\b'</span><span class="p">,</span> <span class="n">expr</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_get_dim_expr</span><span class="p">(</span><span class="n">expr</span><span class="p">):</span>
<span class="k">return</span> <span class="p">[</span>
<span class="n">_get_all_symbols</span><span class="p">(</span><span class="n">match</span><span class="o">.</span><span class="n">group</span><span class="p">())</span>
<span class="k">for</span> <span class="n">match</span> <span class="ow">in</span> <span class="n">re</span><span class="o">.</span><span class="n">finditer</span><span class="p">(</span><span class="sa">r</span><span class="s1">'\b\w+\b|\(.*?\)'</span><span class="p">,</span> <span class="n">expr</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">src_shape_expr</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">dst_shape_expr</span> <span class="o">=</span> <span class="n">expression</span><span class="o">.</span><span class="n">partition</span><span class="p">(</span><span class="s2">"->"</span><span class="p">)</span>
<span class="n">unknown_identifiers</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">findall</span><span class="p">(</span><span class="sa">r</span><span class="s1">'[^a-zA-Z0-9_\(\)]'</span><span class="p">,</span>
|
||
<span class="n">src_shape_expr</span> <span class="o">+</span> <span class="n">dst_shape_expr</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span>
|
||
<span class="n">unknown_identifiers</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"Unknown identifiers: </span><span class="si">{</span><span class="n">unknown_identifiers</span><span class="si">}</span><span class="s2">"</span>
<span class="n">src_identifiers</span> <span class="o">=</span> <span class="n">_get_all_identifier</span><span class="p">(</span><span class="n">src_shape_expr</span><span class="p">)</span>
<span class="n">dst_identifiers</span> <span class="o">=</span> <span class="n">_get_all_identifier</span><span class="p">(</span><span class="n">dst_shape_expr</span><span class="p">)</span>
<span class="k">assert</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">src_identifiers</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">src_identifiers</span><span class="p">))</span>
<span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">dst_identifiers</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">dst_identifiers</span><span class="p">))</span>
<span class="p">),</span> <span class="s2">"Indexing expression contains duplicate dimension."</span>
<span class="k">assert</span> <span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">src_identifiers</span><span class="p">)</span> <span class="o">==</span> <span class="nb">set</span><span class="p">(</span><span class="n">dst_identifiers</span><span class="p">)</span>
<span class="p">),</span> <span class="s2">"Identifiers only on one side of expression (should be on both)."</span>
<span class="n">new_expression</span> <span class="o">=</span> <span class="n">_init_expression</span><span class="p">(</span><span class="n">expression</span><span class="p">)</span>
<span class="n">src_shape_expr</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">dst_shape_expr</span> <span class="o">=</span> <span class="n">new_expression</span><span class="o">.</span><span class="n">partition</span><span class="p">(</span><span class="s2">"->"</span><span class="p">)</span>
<span class="c1"># concat if inputs are sequence of tensors</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">):</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="k">assert</span> <span class="p">(</span>
<span class="n">inputs</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">_get_dim_expr</span><span class="p">(</span><span class="n">src_shape_expr</span><span class="p">))</span>
<span class="p">),</span> <span class="sa">f</span><span class="s2">"inputs.ndim() is </span><span class="si">{</span><span class="n">inputs</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span><span class="si">}</span><span class="s2"> while indexing expression has </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">_get_dim_expr</span><span class="p">(</span><span class="n">src_shape_expr</span><span class="p">))</span><span class="si">}</span><span class="s2">"</span>
<span class="n">src_symbols</span> <span class="o">=</span> <span class="n">_get_all_symbols</span><span class="p">(</span><span class="n">src_shape_expr</span><span class="p">)</span>
<span class="n">dst_symbols</span> <span class="o">=</span> <span class="n">_get_all_symbols</span><span class="p">(</span><span class="n">dst_shape_expr</span><span class="p">)</span>
<span class="c1"># find all the symbols-values mapping and store them in symbol_map</span>
<span class="n">symbol_map</span> <span class="o">=</span> <span class="p">{</span>
<span class="n">symbol</span><span class="p">:</span> <span class="p">{</span>
<span class="s2">"updated"</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
<span class="s2">"value"</span><span class="p">:</span> <span class="kc">None</span>
<span class="p">}</span>
<span class="k">for</span> <span class="n">symbol</span> <span class="ow">in</span> <span class="nb">set</span><span class="p">(</span><span class="n">src_symbols</span> <span class="o">+</span> <span class="n">dst_symbols</span><span class="p">)</span>
<span class="p">}</span>
<span class="k">for</span> <span class="n">symbol</span> <span class="ow">in</span> <span class="n">symbol_map</span><span class="p">:</span>
<span class="k">if</span> <span class="s2">"NumericId"</span> <span class="ow">in</span> <span class="n">symbol</span><span class="p">:</span>
<span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">"value"</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">symbol</span><span class="o">.</span><span class="n">partition</span><span class="p">(</span><span class="s2">"Val"</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">"updated"</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">for</span> <span class="n">symbol</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">"value"</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span>
<span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">"updated"</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">dim_expr</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">_get_dim_expr</span><span class="p">(</span><span class="n">src_shape_expr</span><span class="p">)):</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">dim_expr</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">symbol</span> <span class="o">=</span> <span class="n">dim_expr</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">"updated"</span><span class="p">]:</span>
<span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">"value"</span><span class="p">]</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">idx</span><span class="p">)</span>
<span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">"updated"</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">divisors</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">unknown_symbol</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">for</span> <span class="n">symbol</span> <span class="ow">in</span> <span class="n">dim_expr</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">"updated"</span><span class="p">]:</span>
<span class="n">unknown_symbol</span> <span class="o">=</span> <span class="n">symbol</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">divisors</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">"value"</span><span class="p">])</span>
<span class="k">if</span> <span class="n">unknown_symbol</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">divisors</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span>
<span class="n">divisor</span> <span class="o">=</span> <span class="n">prod</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">divisors</span><span class="p">),</span> <span class="s2">"int64"</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">symbol_map</span><span class="p">[</span><span class="n">unknown_symbol</span><span class="p">][</span><span class="s2">"value"</span><span class="p">]</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span>
<span class="n">idx</span><span class="p">)</span> <span class="o">/</span> <span class="n">divisor</span>
<span class="n">symbol_map</span><span class="p">[</span><span class="n">unknown_symbol</span><span class="p">][</span><span class="s2">"updated"</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">for</span> <span class="n">symbol</span><span class="p">,</span> <span class="n">item</span> <span class="ow">in</span> <span class="n">symbol_map</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">assert</span> <span class="p">(</span><span class="n">item</span><span class="p">[</span><span class="s2">"updated"</span><span class="p">]</span>
<span class="p">),</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">symbol</span><span class="si">}</span><span class="s2"> cannot be inferred, please set it manually"</span>
<span class="n">dst_dims</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">dim_expr</span> <span class="ow">in</span> <span class="n">_get_dim_expr</span><span class="p">(</span><span class="n">dst_shape_expr</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">dim_expr</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">dst_dims</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">symbol_map</span><span class="p">[</span><span class="n">dim_expr</span><span class="p">[</span><span class="mi">0</span><span class="p">]][</span><span class="s2">"value"</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">accumulator</span> <span class="o">=</span> <span class="n">prod</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span><span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">"value"</span><span class="p">]</span> <span class="k">for</span> <span class="n">symbol</span> <span class="ow">in</span> <span class="n">dim_expr</span><span class="p">]),</span>
<span class="s2">"int64"</span><span class="p">),</span>
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">dst_dims</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">accumulator</span><span class="p">)</span>
<span class="n">dst_dims</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">dst_dims</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">),</span> <span class="s2">"int64"</span><span class="p">)</span>
<span class="n">src_indices</span> <span class="o">=</span> <span class="p">{</span><span class="n">symbol</span><span class="p">:</span> <span class="n">idx</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">symbol</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">src_identifiers</span><span class="p">)}</span>
<span class="n">permute_dims</span> <span class="o">=</span> <span class="p">[</span><span class="n">src_indices</span><span class="p">[</span><span class="n">symbol</span><span class="p">]</span> <span class="k">for</span> <span class="n">symbol</span> <span class="ow">in</span> <span class="n">dst_identifiers</span><span class="p">]</span>
<span class="n">symbol_shape</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span><span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">"value"</span><span class="p">]</span> <span class="k">for</span> <span class="n">symbol</span> <span class="ow">in</span> <span class="n">src_identifiers</span><span class="p">],</span>
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">),</span> <span class="s2">"int64"</span><span class="p">)</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">symbol_shape</span><span class="p">)</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">permute</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">permute_dims</span><span class="p">)</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">dst_dims</span><span class="p">)</span>
<span class="k">return</span> <span class="n">tensor</span></div>



<div class="viewcode-block" id="repeat">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.repeat">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">repeat</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">sizes</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">'''</span>
<span class="sd"> Repeats the tensor along the specified dimensions.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The tensor to be repeated.</span>
<span class="sd"> sizes : Sequence[int]</span>
<span class="sd"> The number of times to repeat the tensor along each dimension.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A tensor except for repeated input tensors along specified dim.</span>
<span class="sd"> '''</span>
<span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o"><=</span> <span class="nb">len</span><span class="p">(</span><span class="n">sizes</span><span class="p">),</span> \
|
||
<span class="s2">"Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor"</span>
<span class="n">repeated_tensor</span> <span class="o">=</span> <span class="nb">input</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="nb">len</span><span class="p">(</span><span class="n">sizes</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">):</span>
<span class="n">repeated_tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">repeated_tensor</span><span class="p">]</span> <span class="o">*</span> <span class="n">sizes</span><span class="p">[</span><span class="n">k</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="n">k</span><span class="p">)</span>
<span class="k">return</span> <span class="n">repeated_tensor</span></div>



<div class="viewcode-block" id="repeat_interleave">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.repeat_interleave">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">repeat_interleave</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">repeats</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">'''</span>
<span class="sd"> Repeats elements of a tensor along an axis.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> repeats : int</span>
|
||
<span class="sd"> The number of repetitions along axis specified.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension along which repetitions are performed.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A tensor with the same shape as input except for repeated elements along specified dim.</span>
<span class="sd"> TODO: Allow repeats to be a list of integers and dim to be unspecified.</span>
<span class="sd"> '''</span>
<span class="n">expanded_tensor</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">dim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">tile_output_size</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
<span class="n">repeats</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">dim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="n">expanded_tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">expanded_tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
<span class="p">])</span>
<span class="n">tile</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">expanded_tensor</span><span class="p">,</span> <span class="n">tile_output_size</span><span class="p">)</span>
<span class="n">tile_reshape_size</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">())]</span>
<span class="n">tile_reshape_size</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">tile_reshape_size</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">*</span> <span class="n">repeats</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">tile</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">tile_reshape_size</span><span class="p">))</span>
<span class="k">return</span> <span class="n">tensor</span></div>



<div class="viewcode-block" id="meshgrid2d">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.meshgrid2d">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">meshgrid2d</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">'''</span>
<span class="sd"> Creates grids (2D) of coordinates specified by the 1D inputs (only supports `indexing=\'xy\'`).</span>
<span class="sd"> Parameters:</span>
<span class="sd"> x : Tensor</span>
<span class="sd"> The first input (1D) tensor.</span>
<span class="sd"> y : Tensor</span>
<span class="sd"> The second input (1D) tensor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tuple of two tensors produced.</span>
|
||
|
||
<span class="sd"> TODO: Add full support for torch.meshgrid.</span>
|
||
<span class="sd"> See https://pytorch.org/docs/stable/generated/torch.meshgrid.html#torch-meshgrid</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">x</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">y</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">y</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">grid_x</span> <span class="o">=</span> <span class="n">repeat_interleave</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">shape</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">y</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]])</span>
|
||
<span class="n">grid_y</span> <span class="o">=</span> <span class="n">repeat</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="p">(</span><span class="n">grid_x</span><span class="p">,</span> <span class="n">grid_y</span><span class="p">)</span></div>
<div class="viewcode-block" id="generate_logn_scaling">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.generate_logn_scaling">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">generate_logn_scaling</span><span class="p">(</span><span class="n">seq_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8192</span><span class="p">,</span>
|
||
<span class="n">max_position_embeddings</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32768</span><span class="p">)</span> <span class="o">-></span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Compute the Log-N scaling vector for Qwen inference extrapolation</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> seq_length : int</span>
|
||
<span class="sd"> The max seq length in training (default to 8192 in Qwen-1)</span>
|
||
<span class="sd"> max_position_embeddings : int</span>
|
||
<span class="sd"> The max position embeddings. (default to 32768 in Qwen-1)</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A constant np.ndarray that contains logn scaling vector</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">logn_list</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">seq_length</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="o">></span> <span class="n">seq_length</span> <span class="k">else</span> <span class="mi">1</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">max_position_embeddings</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="p">]</span>
|
||
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">logn_list</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></div>
<div class="viewcode-block" id="generate_alibi_slopes">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.generate_alibi_slopes">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">generate_alibi_slopes</span><span class="p">(</span><span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">tp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">tp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">alibi_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
|
||
<span class="n">alibi_bias_max</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">)</span> <span class="o">-></span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Compute the ALiBi slopes as described in https://arxiv.org/abs/2211.05100.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> num_heads : int</span>
|
||
<span class="sd"> The number of heads.</span>
|
||
<span class="sd"> dtype : trt.DataType</span>
|
||
<span class="sd"> The data type of the returned slopes</span>
|
||
<span class="sd"> tp_size : int</span>
|
||
<span class="sd"> The tensor parallelism size</span>
|
||
<span class="sd"> tp_rank : int</span>
|
||
<span class="sd"> The tensor parallelism rank</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A constant tensor that contains the ALiBi slopes.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">start_head_id</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">end_head_id</span> <span class="o">=</span> <span class="n">num_heads</span>
|
||
|
||
<span class="k">if</span> <span class="n">tp_size</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="n">rank_heads</span> <span class="o">=</span> <span class="n">num_heads</span> <span class="o">//</span> <span class="n">tp_size</span>
|
||
<span class="n">start_head_id</span> <span class="o">=</span> <span class="n">rank_heads</span> <span class="o">*</span> <span class="n">tp_rank</span>
|
||
<span class="n">end_head_id</span> <span class="o">=</span> <span class="n">start_head_id</span> <span class="o">+</span> <span class="n">rank_heads</span>
|
||
|
||
<span class="n">closest_power_of_2</span> <span class="o">=</span> <span class="mi">2</span><span class="o">**</span><span class="n">np</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">num_heads</span><span class="p">))</span>
|
||
<span class="c1"># FT's implementation</span>
|
||
<span class="c1"># https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/gen_relative_pos_bias.cu#L248</span>
|
||
<span class="n">slopes_ft</span> <span class="o">=</span> <span class="p">[]</span>
|
||
<span class="k">for</span> <span class="n">h_id</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">start_head_id</span><span class="p">,</span> <span class="n">end_head_id</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">h_id</span> <span class="o"><</span> <span class="n">closest_power_of_2</span><span class="p">:</span>
|
||
<span class="n">slopes_ft</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">power</span><span class="p">(</span>
|
||
<span class="mi">2</span><span class="o">**</span><span class="p">(</span><span class="o">-</span><span class="p">(</span><span class="mi">2</span><span class="o">**-</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">closest_power_of_2</span><span class="p">)</span> <span class="o">-</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">alibi_bias_max</span><span class="p">)))),</span> <span class="n">h_id</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">slopes_ft</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">power</span><span class="p">(</span>
|
||
<span class="mi">2</span><span class="o">**</span><span class="p">(</span><span class="o">-</span><span class="p">(</span><span class="mi">2</span><span class="o">**-</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">closest_power_of_2</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span> <span class="o">-</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">alibi_bias_max</span><span class="p">)))),</span>
|
||
<span class="p">(</span><span class="n">h_id</span> <span class="o">-</span> <span class="n">closest_power_of_2</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
|
||
<span class="n">slopes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">slopes_ft</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
|
||
|
||
<span class="n">slopes</span> <span class="o">=</span> <span class="n">alibi_scale</span> <span class="o">*</span> <span class="n">slopes</span>
|
||
<span class="n">slopes</span> <span class="o">=</span> <span class="n">slopes</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="n">end_head_id</span> <span class="o">-</span> <span class="n">start_head_id</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">slopes</span></div>
<div class="viewcode-block" id="generate_alibi_biases">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.generate_alibi_biases">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">generate_alibi_biases</span><span class="p">(</span><span class="n">slopes</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">key_length</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Compute the ALiBi biases as described in https://arxiv.org/abs/2211.05100.</span>
|
||
|
||
<span class="sd"> The ALiBi biases are added to the result of the Q*K^T product in the</span>
|
||
<span class="sd"> multi-head attention block.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> slopes : Tensor</span>
|
||
<span class="sd"> The slopes.</span>
|
||
|
||
<span class="sd"> key_length : Tensor</span>
|
||
<span class="sd"> The size of the K vector per head.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A constant tensor that contains the ALiBi biases.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="c1"># We don't need to care about the batch size or query length since we can just broadcast</span>
|
||
<span class="c1"># across the batch and query dimensions</span>
|
||
|
||
<span class="n">trt_0</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
|
||
<span class="n">arange_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">key_length</span><span class="p">])</span>
|
||
|
||
<span class="n">arange_tensor</span> <span class="o">=</span> <span class="n">arange</span><span class="p">(</span><span class="n">trt_0</span><span class="p">,</span> <span class="n">key_length</span><span class="p">,</span> <span class="s2">"float32"</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">arange_shape</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">slopes</span> <span class="o">*</span> <span class="n">arange_tensor</span></div>
<div class="viewcode-block" id="expand_mask">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand_mask">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">expand_mask</span><span class="p">(</span><span class="n">mask</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_len</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Expand an attention mask.</span>
|
||
|
||
<span class="sd"> That function adds the sequence of operations to expand from a tensor of</span>
|
||
<span class="sd"> shape '[batch_size, src_seq_len]' to a tensor of shape</span>
|
||
<span class="sd"> '[batch_size, 1, tgt_seq_len, src_seq_len]'. It can be used to create the</span>
|
||
<span class="sd"> mask applied to the Q*K^T product before the softmax operation in the</span>
|
||
<span class="sd"> multi-head attention block.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> mask : Tensor</span>
|
||
<span class="sd"> The input mask</span>
|
||
|
||
<span class="sd"> tgt_len : Optional[Tensor]</span>
|
||
<span class="sd"> The dimension of the 3rd dimension in the output tensor. If None,</span>
|
||
<span class="sd"> the 2nd dimension of the input is used.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor created by that sequence of operations.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">bsz</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
|
||
<span class="n">src_len</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
|
||
<span class="n">tgt_len</span> <span class="o">=</span> <span class="n">tgt_len</span> <span class="k">if</span> <span class="n">tgt_len</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">src_len</span>
|
||
|
||
<span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">([</span><span class="n">bsz</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">src_len</span><span class="p">]))</span>
|
||
|
||
<span class="n">mask</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">concat</span><span class="p">([</span><span class="n">bsz</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">tgt_len</span><span class="p">,</span> <span class="n">src_len</span><span class="p">]))</span>
|
||
<span class="n">mask</span> <span class="o">=</span> <span class="n">where</span><span class="p">(</span><span class="n">mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s1">'-inf'</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">mask</span></div>
<div class="viewcode-block" id="gather_last_token_logits">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gather_last_token_logits">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">gather_last_token_logits</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">last_token_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">remove_input_padding</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Extract the logits that correspond to the last token from the hidden states.</span>
|
||
|
||
<span class="sd"> That function adds the operations to extract the logits of the last tokens</span>
|
||
<span class="sd"> in a batch of sequences.</span>
|
||
|
||
<span class="sd"> Depending on whether 'remove_input_padding' is 'True' or 'False', that</span>
|
||
<span class="sd"> function assumes inputs of different shapes.</span>
|
||
|
||
<span class="sd"> When 'remove_input_padding' is 'True', the 'hidden_states' tensor is</span>
|
||
<span class="sd"> assumed to be packed. It has a shape '[num_tokens, hidden_dim]' where</span>
|
||
<span class="sd"> 'num_tokens' is the sum of the lengths of the sequences in the batch and</span>
|
||
<span class="sd"> 'hidden_dim' is the hidden dimension. The 'last_tokens_ids' is a 1D tensor</span>
|
||
<span class="sd"> that encodes the inclusive prefix-sums of the lengths of the sequences in</span>
|
||
<span class="sd"> the batch.</span>
|
||
|
||
<span class="sd"> When 'remove_input_padding' is 'False', the 'hidden_states' tensor is</span>
|
||
<span class="sd"> assumed to be padded. It has a shape '[batch_size, max_seqlen, hidden_dim]'</span>
|
||
<span class="sd"> where 'max_seqlen' is the length of the longest sequence in the batch and</span>
|
||
<span class="sd"> 'hidden_dim' is the hidden dimension. The 'last_token_ids' is a 1D tensor</span>
|
||
<span class="sd"> that encodes the length of each sequence in the batch.</span>
|
||
|
||
<span class="sd"> In both cases, that function produces a tensor of shape '[batch_size,</span>
|
||
<span class="sd"> hidden_size]' where the row at index 'i' corresponds to the logits of the</span>
|
||
<span class="sd"> last token from the 'i'-th sequence.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> hidden_states : Tensor</span>
|
||
<span class="sd"> The hidden states</span>
|
||
|
||
<span class="sd"> last_token_ids : Tensor</span>
|
||
<span class="sd"> The inclusive prefix-sum of the lengths or the lengths of the</span>
|
||
<span class="sd"> sequences in the batch.</span>
|
||
|
||
<span class="sd"> remove_input_padding : bool</span>
|
||
<span class="sd"> Indicate if the hidden_states are packed ('True') or padded</span>
|
||
<span class="sd"> ('False').</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The tensor created by that sequence of operations.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">last_token_ids</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">hidden_states</span>
|
||
|
||
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
|
||
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">index_select</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># [seq_len, hidden]</span>
|
||
|
||
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">1</span><span class="p">)]))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">ndim</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="c1"># only calculate logits for the last token</span>
|
||
<span class="c1"># [batch_size, seqlen, hidden_size] -> [batch_size, hidden_size]</span>
|
||
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">]))</span>
|
||
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span>
|
||
<span class="n">last_token_ids</span><span class="p">,</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">2</span><span class="p">)]))</span>
|
||
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span>
|
||
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span>
|
||
<span class="n">hidden_states</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">indices</span><span class="o">=</span><span class="n">last_token_ids</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">2</span><span class="p">)]))</span>
|
||
<span class="k">elif</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span> <span class="c1"># speculative decoding needs last few token's logits</span>
|
||
<span class="c1"># last_token_ids is of shape [batch_size, num_last_tokens]</span>
|
||
<span class="c1"># So [batch_size, seqlen, hidden_size] -> [batch_size, num_last_tokens, hidden_size]</span>
|
||
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
|
||
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="mi">1</span><span class="p">]))</span>
|
||
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span>
|
||
<span class="n">last_token_ids</span><span class="p">,</span>
|
||
<span class="n">concat</span><span class="p">([</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
|
||
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
|
||
<span class="p">]))</span>
|
||
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">indices</span><span class="o">=</span><span class="n">last_token_ids</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">hidden_states</span></div>
<span class="n">ACT2FN</span> <span class="o">=</span> <span class="p">{</span>
|
||
<span class="s1">'relu'</span><span class="p">:</span> <span class="n">relu</span><span class="p">,</span>
|
||
<span class="s1">'tanh'</span><span class="p">:</span> <span class="n">tanh</span><span class="p">,</span>
|
||
<span class="s1">'gelu'</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
|
||
<span class="s1">'gelu_new'</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
|
||
<span class="s1">'gelu_fast'</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
|
||
<span class="s1">'gelu_pytorch_tanh'</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
|
||
<span class="s1">'openai-gelu'</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
|
||
<span class="s1">'geglu'</span><span class="p">:</span> <span class="n">geglu</span><span class="p">,</span>
|
||
<span class="s1">'gegelu'</span><span class="p">:</span> <span class="n">gegelu</span><span class="p">,</span>
|
||
<span class="s1">'identity'</span><span class="p">:</span> <span class="n">identity</span><span class="p">,</span>
|
||
<span class="s1">'silu'</span><span class="p">:</span> <span class="n">silu</span><span class="p">,</span>
|
||
<span class="s1">'softplus'</span><span class="p">:</span> <span class="n">softplus</span><span class="p">,</span>
|
||
<span class="s1">'relu2'</span><span class="p">:</span> <span class="n">squared_relu</span><span class="p">,</span>
|
||
<span class="s1">'squared-relu'</span><span class="p">:</span> <span class="n">squared_relu</span><span class="p">,</span>
|
||
<span class="s1">'swiglu'</span><span class="p">:</span> <span class="n">swiglu</span><span class="p">,</span>
|
||
<span class="s1">'fast-swiglu'</span><span class="p">:</span> <span class="n">swiglu</span><span class="p">,</span>
|
||
<span class="s1">'sigmoid'</span><span class="p">:</span> <span class="n">sigmoid</span><span class="p">,</span>
|
||
<span class="s1">'quick_gelu'</span><span class="p">:</span> <span class="n">quick_gelu</span><span class="p">,</span>
|
||
<span class="p">}</span>
|
||
|
||
<span class="n">GATED_ACT_2_ACT</span> <span class="o">=</span> <span class="p">{</span>
|
||
<span class="s1">'swiglu'</span><span class="p">:</span> <span class="s1">'silu'</span><span class="p">,</span>
|
||
<span class="s1">'fast-swiglu'</span><span class="p">:</span> <span class="s1">'silu'</span><span class="p">,</span>
|
||
<span class="s1">'geglu'</span><span class="p">:</span> <span class="s1">'gelu'</span><span class="p">,</span>
|
||
<span class="p">}</span>
<div class="viewcode-block" id="is_gated_activation">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.is_gated_activation">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">is_gated_activation</span><span class="p">(</span><span class="n">activation</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Is a given activation function gated?</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> activation : str</span>
|
||
<span class="sd"> The name of the activation function.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> True if the function is gated, False otherwise.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="n">activation</span> <span class="ow">in</span> <span class="n">ACT2FN</span>
|
||
<span class="k">return</span> <span class="n">activation</span> <span class="ow">in</span> <span class="n">GATED_ACT_2_ACT</span></div>
<div class="viewcode-block" id="non_gated_version">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.non_gated_version">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">non_gated_version</span><span class="p">(</span><span class="n">activation</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Given an activation function, get the non-gated version.</span>
|
||
|
||
<span class="sd"> If the activation function is non-gated, it returns the same activation</span>
|
||
<span class="sd"> function name.</span>
|
||
|
||
<span class="sd"> For example, that function returns 'silu' for 'swiglu' and 'relu' for</span>
|
||
<span class="sd"> 'relu'.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> activation : str</span>
|
||
<span class="sd"> The name of the activation function.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The name of the non-gated activation function.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="k">if</span> <span class="n">is_gated_activation</span><span class="p">(</span><span class="n">activation</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">GATED_ACT_2_ACT</span><span class="p">[</span><span class="n">activation</span><span class="p">]</span>
|
||
<span class="k">return</span> <span class="n">activation</span></div>
<div class="viewcode-block" id="lora_plugin">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.lora_plugin">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">lora_plugin</span><span class="p">(</span>
|
||
<span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">in_hidden_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">out_hidden_sizes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">],</span>
|
||
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">transa</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">transb</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for pad-free input mode</span>
|
||
<span class="n">max_low_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">lora_ranks</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">lora_weights_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">weight_index</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor (On GPU)</span>
|
||
<span class="sd"> The input tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding</span>
|
||
|
||
<span class="sd"> in_hidden_size/out_hidden_size : int</span>
|
||
<span class="sd"> the lora computation workflow is</span>
|
||
<span class="sd"> [M, in_hidden_size] -> [M, low_rank] -> [M, out_hidden_size]</span>
|
||
|
||
<span class="sd"> host_request_types : Tensor = None</span>
|
||
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
|
||
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
|
||
<span class="sd"> in docs/source/advanced/gpt-attention.md,</span>
|
||
|
||
<span class="sd"> transa : bool</span>
|
||
<span class="sd"> Is the first input transposed? Set to 'True' if you want the first</span>
|
||
<span class="sd"> input to be transposed, 'False' otherwise.</span>
|
||
|
||
<span class="sd"> transb : bool</span>
|
||
<span class="sd"> Is the second input transposed? Set to 'True' if you want the</span>
|
||
<span class="sd"> second input to be transposed, 'False' otherwise.</span>
|
||
|
||
<span class="sd"> host_context_lengths: cpu Tensor = None</span>
|
||
<span class="sd"> A host tensor that contains the lengths of the different inputs,</span>
|
||
|
||
<span class="sd"> max_low_rank : int</span>
|
||
<span class="sd"> Maximum low_rank, used to determine the workspace size.</span>
|
||
|
||
<span class="sd"> lora_ranks : cpu Tensor with shape [batch_size]</span>
|
||
<span class="sd"> The low_rank of each request</span>
|
||
|
||
<span class="sd"> lora_weights_pointers : cpu int64 Tensor with shape [batch_size, 3]</span>
|
||
<span class="sd"> The weights pointers of each request. Consist of in_pointer, out_pointer and possibly a scales vector pointer.</span>
|
||
|
||
<span class="sd"> weight_index : int</span>
|
||
<span class="sd"> The index of weight if the weight pointer pointing to multiple weights.</span>
|
||
|
||
<span class="sd"> Return:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="n">host_context_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">(</span>
|
||
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span>
|
||
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_creator_list</span>
|
||
<span class="n">in_hidden_size_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"in_hidden_size"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">in_hidden_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">out_hidden_size_field_list</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="sa">f</span><span class="s2">"out_hidden_size_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">o</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">o</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">out_hidden_sizes</span><span class="p">)</span>
|
||
<span class="p">]</span>
|
||
<span class="n">transa</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">transa</span> <span class="k">else</span> <span class="mi">0</span>
|
||
<span class="n">transa</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"transa"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">transa</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">transb</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">transb</span> <span class="k">else</span> <span class="mi">0</span>
|
||
<span class="n">transb</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"transb"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">transb</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s1">'Lora'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lora_plugin</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"remove_input_padding"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">max_low_rank_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"max_low_rank"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_low_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">weight_index_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"weight_index"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">weight_index</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">num_lora_modules</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">out_hidden_sizes</span><span class="p">)</span>
|
||
<span class="n">num_lora_modules_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"num_lora_modules"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_lora_modules</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
|
||
<span class="n">in_hidden_size_field</span><span class="p">,</span> <span class="n">transa</span><span class="p">,</span> <span class="n">transb</span><span class="p">,</span> <span class="n">num_lora_modules_field</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span>
|
||
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="n">max_low_rank_field</span><span class="p">,</span> <span class="n">weight_index_field</span>
|
||
<span class="p">]</span> <span class="o">+</span> <span class="n">out_hidden_size_field_list</span><span class="p">)</span>
|
||
<span class="n">lora_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"lora"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">),</span> <span class="n">host_request_types</span>
|
||
<span class="p">]</span> <span class="o">+</span> <span class="n">lora_ranks</span> <span class="o">+</span> <span class="n">lora_weights_pointers</span>
|
||
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_lengths</span><span class="p">]</span>
|
||
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">lora_plug</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">num_lora_modules</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="p">[</span>
|
||
<span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="n">i</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_lora_modules</span><span class="p">)</span>
|
||
<span class="p">]</span></div>
<div class="viewcode-block" id="dora_plugin">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.dora_plugin">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">dora_plugin</span><span class="p">(</span><span class="n">activations</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">out_hidden_sizes</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
|
||
<span class="n">lora_weights_pointers</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span>
|
||
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">|</span> <span class="kc">None</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> The DoRA plugin applies column-wise scaling to the output of a LoRA layer.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor (On GPU)</span>
|
||
<span class="sd"> The input tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding</span>
|
||
|
||
<span class="sd"> out_hidden_sizes : list[int]</span>
|
||
<span class="sd"> The output hidden size of each adapter in the related LoRA module.</span>
|
||
<span class="sd"> For example, for a qkv projection out_hidden_sizes should be [q_dim, k_dim, v_dim].</span>
|
||
|
||
<span class="sd"> host_request_types : Tensor = None</span>
|
||
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
|
||
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
|
||
<span class="sd"> in docs/source/advanced/gpt-attention.md,</span>
|
||
|
||
<span class="sd"> host_context_lengths: cpu Tensor = None</span>
|
||
<span class="sd"> A host tensor that contains the lengths of the different inputs,</span>
|
||
|
||
<span class="sd"> Return:</span>
|
||
<span class="sd"> The tensor produced by that layer.</span>
|
||
|
||
<span class="sd"> '''</span>
|
||
<span class="k">assert</span> <span class="n">host_context_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">(</span>
|
||
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span>
|
||
|
||
<span class="n">dora_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_creator</span><span class="p">(</span>
|
||
<span class="s1">'Dora'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">dora_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">out_hidden_sizes</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"out_hidden_sizes"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">out_hidden_sizes</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"remove_input_padding"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">lora_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lora_plugin</span>
<span class="n">type_id</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"type"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">lora_dtype</span><span class="p">)),</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">(</span>
<span class="p">[</span><span class="n">type_id</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span> <span class="n">out_hidden_sizes</span><span class="p">])</span>
<span class="n">dora_plug</span> <span class="o">=</span> <span class="n">dora_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"dora"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">TensorRTPhase</span><span class="o">.</span><span class="n">BUILD</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">activations</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">lora_dtype</span><span class="p">),</span> <span class="n">host_request_types</span>
<span class="p">]</span> <span class="o">+</span> <span class="n">lora_weights_pointers</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_lengths</span><span class="p">]</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v3</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="p">[],</span> <span class="n">dora_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">dora_plg_creator</span><span class="p">,</span> <span class="s2">"dora"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">activations</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span></div>



<div class="viewcode-block" id="mamba_conv1d">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.mamba_conv1d">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">mamba_conv1d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">conv_state_or_ptr</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">conv_weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">conv_bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dconv</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">pre_stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">post_stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">slot_mapping</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">apply_silu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
<span class="w"> </span><span class="sd">'''</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor (On GPU)</span>
<span class="sd"> The input tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding</span>
<span class="sd"> conv_state_or_ptr : Tensor (On GPU or CPU)</span>
<span class="sd"> The conv state tensor. Its shape is [batch_size, dconv - 1, dim]</span>
<span class="sd"> Or a CPU tensor of shape [1] holding the pointer to the paged states.</span>

<span class="sd"> conv_weight : Tensor (On GPU)</span>
<span class="sd"> The weight tensor. Its shape is [1, dconv, dim]</span>
<span class="sd"> conv_bias : Tensor (On GPU)</span>
<span class="sd"> The bias tensor. Its shape is [dim]</span>
<span class="sd"> host_request_types : Tensor (On CPU)</span>
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
<span class="sd"> in docs/source/advanced/gpt-attention.md.</span>

<span class="sd"> last_token_ids : Tensor (On GPU)</span>
<span class="sd"> The inclusive prefix-sum of the lengths or the lengths of the</span>
<span class="sd"> sequences in the batch.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The hidden dimension of conv1d</span>
<span class="sd"> dconv : int</span>
<span class="sd"> The window size of conv1d</span>
<span class="sd"> dtype: str</span>
<span class="sd"> data type</span>
<span class="sd"> pre_stride : int = 0</span>
<span class="sd"> The (pre) stride size of the input tensor.</span>
<span class="sd"> The valid values of the input tensor are input[..., pre_stride: dim-post_stride]</span>
<span class="sd"> post_stride : int = 0</span>
<span class="sd"> The (post) stride size of the input tensor.</span>
<span class="sd"> The valid values of the input tensor are input[..., pre_stride: dim-post_stride]</span>
<span class="sd"> host_context_lengths: Tensor (On CPU) (Optional)</span>
<span class="sd"> A host tensor that contains the lengths of the different inputs.</span>

<span class="sd"> slot_mapping: Tensor (On GPU) (Optional)</span>
<span class="sd"> The real page index in the state, used for the paged state. Its shape is [dim]; each page has shape [dconv, dim].</span>

<span class="sd"> apply_silu: bool</span>
<span class="sd"> Whether to apply the SiLU activation function after the conv1d.</span>
<span class="sd"> '''</span>
<span class="k">assert</span> <span class="n">host_request_types</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">mamba_conv1d_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">'MambaConv1d'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">mamba_conv1d_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"dim"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">dconv</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"dconv"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dconv</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pre_stride</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"pre_stride"</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">pre_stride</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">post_stride</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"post_stride"</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">post_stride</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"remove_input_padding"</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">paged_state</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"paged_state"</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">apply_silu</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"apply_silu"</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">apply_silu</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
<span class="n">dim</span><span class="p">,</span> <span class="n">dconv</span><span class="p">,</span> <span class="n">pre_stride</span><span class="p">,</span> <span class="n">post_stride</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span>
<span class="n">paged_state</span><span class="p">,</span> <span class="n">apply_silu</span>
<span class="p">])</span>
<span class="n">mamba_conv1d_plug</span> <span class="o">=</span> <span class="n">mamba_conv1d_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span>
<span class="s2">"mamba_conv1d"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span>
<span class="nb">input</span><span class="p">,</span> <span class="n">conv_state_or_ptr</span><span class="p">,</span> <span class="n">conv_weight</span><span class="p">,</span> <span class="n">conv_bias</span><span class="p">,</span> <span class="n">host_request_types</span><span class="p">,</span>
<span class="n">last_token_ids</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_lengths</span><span class="p">]</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">slot_mapping</span><span class="p">]</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">mamba_conv1d_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">mamba_conv1d_plg_creator</span><span class="p">,</span> <span class="s2">"mamba_conv1d"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">present_state</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">present_state</span></div>



<div class="viewcode-block" id="selective_scan">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.selective_scan">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">selective_scan</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">state_or_ptr</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">delta</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">delta_bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">A</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">BC</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">D</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dstate</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dt_rank</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">delta_softplus</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">z</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">slot_mapping</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">nheads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">ngroups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">chunk_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span><span class="p">,</span>
<span class="n">mamba_version</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">'Mamba1'</span><span class="p">):</span>
<span class="w"> </span><span class="sd">'''</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor (On GPU)</span>
<span class="sd"> The input tensor. Its shape is [batch_size, seq_len, dim]</span>
<span class="sd"> state_or_ptr : Tensor (On GPU or CPU)</span>
<span class="sd"> The ssm state tensor. Its shape is [batch_size, dstate, dim]</span>
<span class="sd"> Or a CPU tensor of shape [1] holding the pointer to the paged states.</span>

<span class="sd"> delta : Tensor (On GPU)</span>
<span class="sd"> The delta tensor.</span>
<span class="sd"> mamba: Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding</span>
<span class="sd"> mamba2: Its shape is [batch_size, seq_len, nheads] or [num_tokens, nheads] for remove_input_padding</span>
<span class="sd"> delta_bias : Tensor (On GPU)</span>
<span class="sd"> The delta bias tensor.</span>
<span class="sd"> mamba: Its shape is [dim]</span>
<span class="sd"> mamba2: Its shape is [nheads]</span>
<span class="sd"> A : Tensor (On GPU)</span>
<span class="sd"> A matrix.</span>
<span class="sd"> mamba: Its shape is [dstate, dim]</span>
<span class="sd"> mamba2: Its shape is [nheads]</span>
<span class="sd"> BC : Tensor (On GPU)</span>
<span class="sd"> B and C matrix.</span>
<span class="sd"> mamba: Its shape is [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding</span>
<span class="sd"> mamba2: Its shape is [batch_size, seq_len, ngroups * dstate * 2] or [num_tokens, ngroups * dstate * 2] for remove_input_padding</span>
<span class="sd"> D : Tensor (On GPU)</span>
<span class="sd"> D matrix.</span>
<span class="sd"> mamba: Its shape is [dim]</span>
<span class="sd"> mamba2: Its shape is [nheads]</span>
<span class="sd"> host_request_types : Tensor (On CPU)</span>
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
<span class="sd"> in docs/source/advanced/gpt-attention.md</span>
<span class="sd"> last_token_ids : Tensor (On GPU)</span>
<span class="sd"> The inclusive prefix-sum of the lengths or the lengths of the</span>
<span class="sd"> sequences in the batch.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The inner dimension of SSM block</span>
<span class="sd"> dstate : int</span>
<span class="sd"> The state dimension of SSM block</span>
<span class="sd"> dt_rank: int</span>
<span class="sd"> The rank dimension of dt_proj</span>
<span class="sd"> delta_softplus : bool</span>
<span class="sd"> Whether to apply softplus to the delta.</span>

<span class="sd"> dtype: str</span>
<span class="sd"> data type</span>
<span class="sd"> z : Tensor (On GPU) (Optional)</span>
<span class="sd"> The z tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding</span>
<span class="sd"> host_context_lengths: Tensor (On CPU) (Optional)</span>
<span class="sd"> A host tensor that contains the lengths of the different inputs,</span>
<span class="sd"> slot_mapping: Tensor (On GPU) (Optional)</span>
<span class="sd"> The real page index in the state, used for the paged state. Its shape is [dim]; each page has shape [dstate, dim].</span>

<span class="sd"> nheads: int (Optional)</span>
<span class="sd"> The number of heads.</span>
<span class="sd"> ngroups: int (Optional)</span>
<span class="sd"> The number of groups.</span>
<span class="sd"> chunk_size: int (Optional)</span>
<span class="sd"> The chunk_size is used for the chunk_scan kernel.</span>
<span class="sd"> mamba_version: int (Optional)</span>
|
||
<span class="sd"> Mamba version, support Mamba1 as default.</span>
<span class="sd"> '''</span>
<span class="k">assert</span> <span class="n">host_request_types</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">selective_scan_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">'SelectiveScan'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">selective_scan_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"dim"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">dstate</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"dstate"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dstate</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">dt_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"dt_rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dt_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">nheads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"nheads"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">nheads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">ngroups</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"ngroups"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">ngroups</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">chunk_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"chunk_size"</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">chunk_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">delta_softplus</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"delta_softplus"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">delta_softplus</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"remove_input_padding"</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">paged_state</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"paged_state"</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="k">if</span> <span class="n">z</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">z_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"z_enabled"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">z_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"z_enabled"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">is_mamba2</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"is_mamba2"</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span> <span class="k">if</span> <span class="n">mamba_version</span> <span class="o">==</span> <span class="s1">'Mamba2'</span> <span class="k">else</span> <span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
<span class="n">dim</span><span class="p">,</span> <span class="n">dstate</span><span class="p">,</span> <span class="n">dt_rank</span><span class="p">,</span> <span class="n">nheads</span><span class="p">,</span> <span class="n">ngroups</span><span class="p">,</span> <span class="n">chunk_size</span><span class="p">,</span> <span class="n">delta_softplus</span><span class="p">,</span>
<span class="n">pf_type</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span> <span class="n">paged_state</span><span class="p">,</span> <span class="n">z_enabled</span><span class="p">,</span> <span class="n">is_mamba2</span>
<span class="p">])</span>
<span class="n">selective_scan_plug</span> <span class="o">=</span> <span class="n">selective_scan_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span>
<span class="s2">"selective_scan"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span>
<span class="nb">input</span><span class="p">,</span> <span class="n">state_or_ptr</span><span class="p">,</span> <span class="n">delta</span><span class="p">,</span> <span class="n">delta_bias</span><span class="p">,</span> <span class="n">A</span><span class="p">,</span> <span class="n">BC</span><span class="p">,</span> <span class="n">D</span><span class="p">,</span> <span class="n">host_request_types</span><span class="p">,</span>
<span class="n">last_token_ids</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_lengths</span><span class="p">]</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">slot_mapping</span><span class="p">]</span>
<span class="k">if</span> <span class="n">z</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">z</span><span class="p">]</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">selective_scan_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">selective_scan_plg_creator</span><span class="p">,</span> <span class="s2">"selective_scan"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">present_state</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">present_state</span></div>



<div class="viewcode-block" id="rg_lru">
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.rg_lru">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">rg_lru</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">A</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">state_or_ptr</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">block_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">y</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">y_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">gate</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">gate_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">gate_x</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">gate_x_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">gate_a</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">gate_a_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">slot_mapping</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="w"> </span><span class="sd">'''</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor (On GPU)</span>
<span class="sd"> The input tensor. Its shape is [batch_size, seq_len, dim]</span>
<span class="sd"> A : Tensor (On GPU)</span>
<span class="sd"> A matrix. Its shape is [dim]</span>
<span class="sd"> state_or_ptr : Tensor (On GPU or CPU)</span>
<span class="sd"> The lru state tensor. Its shape is [batch_size, dstate, dim]</span>
<span class="sd"> Or a CPU tensor of shape [1] holding the pointer to the paged states.</span>

<span class="sd"> host_request_types : Tensor (On CPU)</span>
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
<span class="sd"> in docs/source/advanced/gpt-attention.md.</span>

<span class="sd"> last_token_ids : Tensor (On GPU)</span>
<span class="sd"> The inclusive prefix-sum of the lengths or the lengths of the</span>
<span class="sd"> sequences in the batch.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The inner dimension of RG_LRU block</span>
<span class="sd"> block_size : int</span>
<span class="sd"> The block size of the block diagonal linear layer. It is used to</span>
<span class="sd"> support the cases where the fused gate is enabled.</span>
<span class="sd"> dtype: str</span>
<span class="sd"> data type</span>
<span class="sd"> y : Tensor (On GPU) (Optional)</span>
<span class="sd"> The y tensor. Its shape is [batch_size, seq_len, dim]</span>
<span class="sd"> y_bias : Tensor (On GPU) (Optional)</span>
<span class="sd"> The y_bias tensor. Its shape is [dim]. If y_bias is not None, we</span>
<span class="sd"> will fuse GELU(y + y_bias) in this function.</span>
<span class="sd"> gate : Tensor (On GPU) (Optional)</span>
<span class="sd"> The gate tensor. Its shape is [batch_size, seq_len, 2 * dim].</span>
<span class="sd"> If gate is not None, gate_x and gate_a are fused into it; otherwise</span>
<span class="sd"> those two tensors are used separately.</span>
<span class="sd"> gate_bias : Tensor (On GPU) (Optional)</span>
<span class="sd"> The gate_bias tensor. Its shape is [2 * block_num, dim // block_num].</span>
<span class="sd"> If gate_bias is not None, we will fuse the bias add in this function.</span>
<span class="sd"> gate_x : Tensor (On GPU) (Optional)</span>
<span class="sd"> The gate_x tensor. Its shape is [batch_size, seq_len, dim]</span>
<span class="sd"> gate_x_bias : Tensor (On GPU) (Optional)</span>
<span class="sd"> The gate_x_bias tensor. Its shape is [block_num, dim // block_num].</span>
<span class="sd"> If gate_x_bias is not None, we will fuse the bias add in this function.</span>
<span class="sd"> gate_a : Tensor (On GPU) (Optional)</span>
<span class="sd"> The gate_a tensor. Its shape is [batch_size, seq_len, dim]</span>
<span class="sd"> gate_a_bias : Tensor (On GPU) (Optional)</span>
<span class="sd"> The gate_a_bias tensor. Its shape is [block_num, dim // block_num].</span>
<span class="sd"> If gate_a_bias is not None, we will fuse the bias add in this function.</span>
<span class="sd"> slot_mapping: Tensor (On GPU) (Optional)</span>
<span class="sd"> The real page index in the state, used for the paged state. Its shape is [dim]; each page has shape [dstate, dim].</span>
<span class="sd"> '''</span>
<span class="k">assert</span> <span class="n">host_request_types</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">lru_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">'LRU'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">lru_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="p">(</span><span class="n">gate_x_bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">)</span> <span class="o">==</span> <span class="p">(</span><span class="n">gate_a_bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">enable_fuse_gate</span> <span class="o">=</span> <span class="n">gate</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">has_gate_bias</span> <span class="o">=</span> <span class="p">(</span><span class="n">gate_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">gate_x_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">if</span> <span class="n">enable_fuse_gate</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">gate</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="n">block_size</span> <span class="o">></span> <span class="mi">0</span>
<span class="k">if</span> <span class="n">has_gate_bias</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">gate_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">gate_x</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">gate_a</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">has_gate_bias</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">gate_x_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">gate_a_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"dim"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">block_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"block_size"</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">block_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"remove_input_padding"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="n">paged_state</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"paged_state"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">),</span>
|
||
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">y</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">y_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"y_enabled"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">y_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"y_enabled"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">y_bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">y_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"y_bias_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">y_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"y_bias_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">enable_fuse_gate</span><span class="p">:</span>
|
||
<span class="n">fuse_gate_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"fuse_gate_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">fuse_gate_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"fuse_gate_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">has_gate_bias</span><span class="p">:</span>
|
||
<span class="n">gate_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"gate_bias_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">gate_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"gate_bias_enabled"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
|
||
<span class="n">dim</span><span class="p">,</span> <span class="n">block_size</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span> <span class="n">paged_state</span><span class="p">,</span> <span class="n">y_enabled</span><span class="p">,</span>
|
||
<span class="n">y_bias_enabled</span><span class="p">,</span> <span class="n">fuse_gate_enabled</span><span class="p">,</span> <span class="n">gate_bias_enabled</span>
|
||
<span class="p">])</span>
|
||
<span class="n">lru_plug</span> <span class="o">=</span> <span class="n">lru_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"rg_lru"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="nb">input</span><span class="p">,</span>
|
||
<span class="n">A</span><span class="p">,</span>
|
||
<span class="n">state_or_ptr</span><span class="p">,</span>
|
||
<span class="n">host_request_types</span><span class="p">,</span>
|
||
<span class="n">last_token_ids</span><span class="p">,</span>
|
||
<span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">slot_mapping</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">y</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">y</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">y_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">y_bias</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">enable_fuse_gate</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gate</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">has_gate_bias</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gate_bias</span><span class="p">]</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gate_x</span><span class="p">,</span> <span class="n">gate_a</span><span class="p">]</span>
|
||
<span class="k">if</span> <span class="n">has_gate_bias</span><span class="p">:</span>
|
||
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gate_x_bias</span><span class="p">,</span> <span class="n">gate_a_bias</span><span class="p">]</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">lru_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">lru_plg_creator</span><span class="p">,</span> <span class="s2">"rg_lru"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="kc">None</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">present_state</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">present_state</span></div>
|
||
|
||
|
||
|
||
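
# --- Editor's illustration (not part of the original module) ---------------
# The rg_lru wrapper above assembles the plugin's optional inputs in a fixed
# order. The sketch below mirrors that ordering logic with plain strings so it
# can be read (and run) without TensorRT; the names are illustrative only,
# standing in for the Tensor arguments the real wrapper receives.
def _rg_lru_input_order_demo(paged_state: bool, with_y: bool,
                             with_y_bias: bool, fuse_gate: bool,
                             gate_bias: bool) -> list:
    """Return the input order the rg_lru plugin would see (illustrative)."""
    inputs = [
        "input", "A", "state_or_ptr", "host_request_types", "last_token_ids"
    ]
    if paged_state:
        inputs += ["slot_mapping"]
    if with_y:
        inputs += ["y"]
    if with_y_bias:
        inputs += ["y_bias"]
    if fuse_gate:
        inputs += ["gate"]
        if gate_bias:
            inputs += ["gate_bias"]
    else:
        inputs += ["gate_x", "gate_a"]
        if gate_bias:
            inputs += ["gate_x_bias", "gate_a_bias"]
    return inputs
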

def topk(input: Tensor,
         k: Union[Tensor, int],
         dim: int,
         largest: bool = True,
         prefer_plugin: bool = True) -> Tuple[Tensor, Tensor]:
    '''
    Add a topk operation.

    As explained in the ONNX documentation,

        https://github.com/onnx/onnx/blob/main/docs/Operators.md#topk

    NOTE: One distinction from the ONNX topk op: with the TensorRT layer, the
    output is always sorted.

    Retrieve the top-K largest elements along a specified axis.
    Given an input tensor of shape [a_1, a_2, ..., a_n, r]
    and integer argument k, return two outputs:
    Value tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which contains the values of the top k elements along the specified axis.
    Index tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which contains the indices of the top k elements (original indices from the input tensor).

    Parameters:
        input : Tensor
            The input tensor.

        k : int
            A single positive value corresponding to the number of top elements to retrieve.

        dim : int
            The dimension in which to compute the topk indices.

        largest : bool
            Controls whether to return the largest or smallest elements.

        prefer_plugin : bool
            Whether to use the topkLastDim plugin if dim is the last dim and k is static.

    Returns:
        The tensors (values, indices) produced by this topk operation.
    '''
    dim = dim_resolve_negative(dim, input.ndim())[0]
    if prefer_plugin and dim == input.ndim() - 1 and not isinstance(k, Tensor):
        last_dim = input.size(-1)
        if last_dim == -1:  # dynamic?
            last_dim = shape(input, -1)
        # since we might need to flatten the input to a 2d tensor,
        # we need to prepare the output shape
        out_shape = []
        for i in range(input.ndim() - 1):
            out_shape.append(shape(input, i))
        out_shape = concat(out_shape + [k])
        if input.ndim() == 1:
            input_2d = unsqueeze(input,
                                 0)  # special handling of rank-1 dynamic tensor
        elif input.ndim() != 2:
            input_2d = input.view(concat([-1, last_dim]),
                                  zero_is_placeholder=False)
        else:
            input_2d = input
        plg_creator = trt.get_plugin_registry().get_plugin_creator(
            "TopkLastDim", "1", TRT_LLM_PLUGIN_NAMESPACE)
        assert plg_creator is not None
        is_largest = trt.PluginField(
            "is_largest", np.array(1 if largest else 0, dtype=np.int32),
            trt.PluginFieldType.INT32)
        k = trt.PluginField("k", np.array(k, dtype=np.int32),
                            trt.PluginFieldType.INT32)
        pf_type = trt.PluginField("type_id",
                                  np.array([int(input_2d.dtype)], np.int32),
                                  trt.PluginFieldType.INT32)
        pfc = trt.PluginFieldCollection([pf_type, k, is_largest])
        topk_last_dim_plug = plg_creator.create_plugin("topk_last_dim", pfc)
        plug_inputs = [input_2d]
        plug_inputs = [i.trt_tensor for i in plug_inputs]
        layer = default_trtnet().add_plugin_v2(plug_inputs, topk_last_dim_plug)
        _add_plugin_info(layer, plg_creator, "topk_last_dim", pfc)
        values = _create_tensor(layer.get_output(0), layer)
        indices = _create_tensor(layer.get_output(1), layer)
        values = values.view(out_shape, zero_is_placeholder=False)
        indices = indices.view(out_shape, zero_is_placeholder=False)
    else:
        # non-plugin path
        axes = dim_to_trt_axes(dim)
        layer = default_trtnet().add_topk(
            input.trt_tensor,
            trt.TopKOperation.MAX if largest else trt.TopKOperation.MIN,
            k=k if not isinstance(k, Tensor) else 1,
            axes=axes)
        if isinstance(k, Tensor):
            if k.ndim() == 1:
                k = squeeze(k, 0)
            layer.set_input(1, k.trt_tensor)
        values = _create_tensor(layer.get_output(0), layer)
        indices = _create_tensor(layer.get_output(1), layer)

    return values, indices
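
# --- Editor's illustration (not part of the original module) ---------------
# A self-contained NumPy reference for the top-k semantics documented above
# (values and indices sorted along the reduced axis). It is not used by the
# TensorRT graph code in this module.
def _topk_reference(x: np.ndarray, k: int, largest: bool = True):
    """Return (values, indices) of the sorted top-k along the last axis."""
    order = np.argsort(-x if largest else x, axis=-1)[..., :k]
    return np.take_along_axis(x, order, axis=-1), order
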

def scatter_nd(input: Tensor, mask: Tensor, source: Tensor) -> Tensor:
    '''
    Scatter_nd is a tensor operation that writes or updates values in a tensor
    based on indices.

    Parameters:
        input : Tensor
            The input tensor to be updated.

        mask : Tensor
            A tensor of indices specifying the locations in the input tensor
            to be updated.

        source : Tensor
            A tensor of values to be written or scattered into the input
            tensor.

    Returns:
        A new tensor with the same shape as the input tensor, where the values
        from the source tensor are scattered or written into the output tensor
        at the locations specified by the mask tensor.
    '''
    scatter_layer = default_trtnet().add_scatter(input.trt_tensor,
                                                 mask.trt_tensor,
                                                 source.trt_tensor,
                                                 mode=trt.ScatterMode.ND)
    return _create_tensor(scatter_layer.get_output(0), scatter_layer)
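
# --- Editor's illustration (not part of the original module) ---------------
# ONNX-style ScatterND on NumPy arrays, matching the semantics of scatter_nd
# above; `indices` plays the role of the `mask` argument.
def _scatter_nd_reference(data: np.ndarray, indices: np.ndarray,
                          updates: np.ndarray) -> np.ndarray:
    """Copy `data`, writing `updates` at the index tuples in `indices`."""
    out = data.copy()
    r = indices.shape[-1]  # each index tuple addresses data[i_0, ..., i_{r-1}]
    flat_idx = indices.reshape(-1, r)
    flat_upd = updates.reshape((-1, ) + data.shape[r:])
    for idx, upd in zip(flat_idx, flat_upd):
        out[tuple(idx)] = upd
    return out
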

def low_latency_gemm(input: Tensor,
                     mat2: Tensor,
                     alpha: Optional[np.ndarray] = None,
                     strict_dtype: Optional[trt.DataType] = None) -> Tensor:
    if not default_net().plugin_config.low_latency_gemm_plugin:
        raise RuntimeError(
            "Low Latency GEMM is only supported with the plugin")
    elif default_net().plugin_config.low_latency_gemm_plugin != "fp8":
        raise RuntimeError("Low Latency GEMM plugin only supports fp8")
    else:
        plg_creator = trt.get_plugin_registry().get_plugin_creator(
            "LowLatencyGemm", "1", TRT_LLM_PLUGIN_NAMESPACE)
        assert plg_creator is not None
        if (input.dtype != trt.fp8) or (mat2.dtype != trt.fp8):
            raise TypeError("Low Latency GEMM only supports fp8 inputs")
        if alpha is not None:
            assert (isinstance(alpha, np.ndarray)
                    and alpha.dtype == np.float32 and alpha.size
                    == 1), "`alpha` must be passed as a float32 ndarray"
        alpha = alpha if alpha is not None else np.array(1.0, dtype=np.float32)
        alpha = trt.PluginField("alpha", alpha.flatten(),
                                trt.PluginFieldType.FLOAT32)

        if strict_dtype is not None:
            assert isinstance(strict_dtype, trt.DataType)
            p_dtype = strict_dtype
            if p_dtype not in [trt.float32, trt.float16, trt.bfloat16]:
                raise ValueError(
                    "strict_dtype must be float32, float16 or bfloat16 in the low latency gemm plugin"
                )
        else:
            raise RuntimeError(
                "need to use a strict dtype in the low latency gemm (fp8) plugin")
        pf_type = trt.PluginField("type_id", np.array([int(p_dtype)], np.int32),
                                  trt.PluginFieldType.INT32)
        pfc = trt.PluginFieldCollection([alpha, pf_type])
        low_latency_gemm_plug = plg_creator.create_plugin(
            "low_latency_gemm", pfc)
        plug_inputs = [input.trt_tensor, mat2.trt_tensor]
        layer = default_trtnet().add_plugin_v2(plug_inputs,
                                               low_latency_gemm_plug)
        _add_plugin_info(layer, plg_creator, "low_latency_gemm", pfc)
        return _create_tensor(layer.get_output(0), layer)
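
# --- Editor's illustration (not part of the original module) ---------------
# The math of the fp8 plugin above, up to layout and quantization details
# handled inside the kernel: an alpha-scaled GEMM. Whether mat2 is transposed
# is a kernel implementation detail; this sketch assumes plain A @ B.
def _low_latency_gemm_reference(a: np.ndarray, b: np.ndarray,
                                alpha: float = 1.0) -> np.ndarray:
    """Return alpha * (A @ B) in float32 (illustrative)."""
    return alpha * (a.astype(np.float32) @ b.astype(np.float32))
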

class SideStreamIDType(IntEnum):
    disable = 0
    moe = 1

def low_latency_gemm_swiglu(input: Tensor,
                            weight: Tensor,
                            scale_d0: float = 1.0,
                            scale_d1: float = 1.0,
                            scale_output: float = 1.0) -> Tensor:
    '''
    Add a matrix multiplication, followed by a SwiGLU (`x * SiLU(gate)`)
    operation.

    The SwiGLU operation takes the preceding tensor, splits it into two halves
    along the last dimension, applies SiLU to the second half and multiplies
    the results. The behaviour is undefined if the last dimension is not even.

    Parameters:
        input : Tensor
            The first tensor (often called A).

        weight : Tensor
            The second tensor (often called B).

        scale_d0 : float
            The scale for dequantizing x, used for fp8.

        scale_d1 : float
            The scale for dequantizing gate, used for fp8.

        scale_output : float
            The scale for quantizing the output, used for fp8.

    Returns:
        The tensor produced by the inserted layer.
    '''
    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'LowLatencyGemmSwiglu', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert plg_creator is not None

    p_dtype = default_net().plugin_config.low_latency_gemm_swiglu_plugin
    pf_type = trt.PluginField(
        "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
        trt.PluginFieldType.INT32)
    pf_scale_d0 = trt.PluginField("scale_d0",
                                  np.array(scale_d0, dtype=np.float32),
                                  trt.PluginFieldType.FLOAT32)
    pf_scale_d1 = trt.PluginField("scale_d1",
                                  np.array(scale_d1, dtype=np.float32),
                                  trt.PluginFieldType.FLOAT32)
    pf_scale_output = trt.PluginField("scale_output",
                                      np.array(scale_output, dtype=np.float32),
                                      trt.PluginFieldType.FLOAT32)

    pfc = trt.PluginFieldCollection(
        [pf_type, pf_scale_output, pf_scale_d0, pf_scale_d1])
    low_latency_gemm_swiglu_plug = plg_creator.create_plugin(
        "low_latency_gemm_swiglu", pfc)

    plug_inputs = [input.trt_tensor, weight.trt_tensor]

    layer = default_trtnet().add_plugin_v2(plug_inputs,
                                           low_latency_gemm_swiglu_plug)

    return _create_tensor(layer.get_output(0), layer)
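
# --- Editor's illustration (not part of the original module) ---------------
# A NumPy sketch of the fused GEMM + SwiGLU math described in the docstring
# above. How exactly scale_d0/scale_d1/scale_output enter the fp8 kernel is an
# assumption here (dequantize by multiplying, quantize by dividing); treat
# this as a sketch of the shape semantics, not of the plugin's numerics.
def _gemm_swiglu_reference(a: np.ndarray,
                           b: np.ndarray,
                           scale_d0: float = 1.0,
                           scale_d1: float = 1.0,
                           scale_output: float = 1.0) -> np.ndarray:
    h = a.astype(np.float32) @ b.astype(np.float32)
    x, gate = np.split(h, 2, axis=-1)  # undefined if the last dim is odd
    gate = gate * scale_d1
    silu = gate / (1.0 + np.exp(-gate))  # SiLU(g) = g * sigmoid(g)
    return (x * scale_d0) * silu / scale_output
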
<div class="viewcode-block" id="cuda_stream_sync">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.cuda_stream_sync">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">cuda_stream_sync</span><span class="p">(</span><span class="n">input_list</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span>
|
||
<span class="n">side_stream_id</span><span class="p">:</span> <span class="n">SideStreamIDType</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Wait for the side stream on the main stream.</span>
|
||
<span class="sd"> output = input_list[0]</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input_list : List[Tensor] (On GPU)</span>
|
||
<span class="sd"> The list of input tensors.</span>
|
||
<span class="sd"> side_stream_id : int (On CPU)</span>
|
||
<span class="sd"> The side stream ID.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
|
||
<span class="s2">"CudaStream"</span><span class="p">,</span> <span class="s2">"1"</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">p_side_stream_id</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"side_stream_id"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">side_stream_id</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">p_num_inputs</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"num_inputs"</span><span class="p">,</span>
|
||
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">input_list</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
|
||
<span class="s2">"type_id"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">input_list</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">dtype</span><span class="p">)],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">p_side_stream_id</span><span class="p">,</span> <span class="n">p_num_inputs</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
|
||
<span class="n">plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"cuda_stream"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="nb">input</span> <span class="ow">in</span> <span class="n">input_list</span><span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creator</span><span class="p">,</span> <span class="s2">"cuda_stream"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="cp_split_plugin">
|
||
<a class="viewcode-back" href="../../python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.cp_split_plugin">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">cp_split_plugin</span><span class="p">(</span>
|
||
<span class="n">input_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
|
||
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="c1"># for pad-free input mode</span>
|
||
<span class="n">cp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">cp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
|
||
<span class="p">)</span> <span class="o">-></span> <span class="n">Tensor</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Add an operation to perform splitting for context parallelism.</span>
|
||
|
||
<span class="sd"> This operation split the input_ids into cp_size chunks, and return the cp_rank-th</span>
|
||
<span class="sd"> chunk.</span>
|
||
<span class="sd"> When the seqlen % cp_size != 0, the chunk sizes of each rank would be</span>
|
||
<span class="sd"> [seqlen // cp_size, seqlen // cp_size, ..., seqlen - (seqlen // cp_size) * cp_size]</span>
|
||
|
||
<span class="sd"> It inserts a IPluginV3Layer.</span>
|
||
|
||
<span class="sd"> Parameters:</span>
|
||
<span class="sd"> input : Tensor</span>
|
||
<span class="sd"> The input tensor contains the indices to split.</span>
|
||
|
||
<span class="sd"> host_request_types: Tensor = None (On CPU)</span>
|
||
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
|
||
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
|
||
<span class="sd"> in docs/gpt_attention.md,</span>
|
||
|
||
<span class="sd"> host_context_lengths: Tensor = None (On CPU)</span>
|
||
<span class="sd"> A host tensor that contains the lengths of the different inputs</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> The output split tensor.</span>
|
||
<span class="sd"> The length of the output split tensor.</span>
|
||
<span class="sd"> The index for rebuilding the sequence</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_creator</span><span class="p">(</span>
|
||
<span class="s1">'CpSplit'</span><span class="p">,</span> <span class="s1">'1'</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
|
||
<span class="n">cp_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"cp_size"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">cp_size</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
<span class="n">cp_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">"cp_rank"</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">cp_rank</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
|
||
|
||
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">cp_size</span><span class="p">,</span> <span class="n">cp_rank</span><span class="p">])</span>
|
||
<span class="n">cp_split_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">"cp_split"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">,</span>
|
||
<span class="n">trt</span><span class="o">.</span><span class="n">TensorRTPhase</span><span class="o">.</span><span class="n">BUILD</span><span class="p">)</span>
|
||
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span>
|
||
<span class="n">input_ids</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">host_request_types</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
|
||
<span class="n">host_context_lengths</span><span class="o">.</span><span class="n">trt_tensor</span>
|
||
<span class="p">]</span>
|
||
|
||
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v3</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="p">[],</span> <span class="n">cp_split_plug</span><span class="p">)</span>
|
||
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creator</span><span class="p">,</span> <span class="s2">"cp_split"</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
|
||
<span class="n">layer</span><span class="p">),</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
</pre></div>
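<p>Note the two plugin interfaces used by the wrappers above: the
<code>cuda_stream</code> helper creates an <code>IPluginV2</code> plugin with
<code>create_plugin(name, pfc)</code> and attaches it via
<code>add_plugin_v2(plug_inputs, plug)</code>, while
<code>cp_split_plugin</code> creates an <code>IPluginV3</code> plugin, whose
creator additionally takes the TensorRT phase at creation time and whose layer
takes a (here empty) list of shape inputs. A schematic contrast, copied from
the two code paths above:</p>
<pre>
# IPluginV2 path (cuda_stream helper)
plug = plg_creator.create_plugin("cuda_stream", pfc)
layer = default_trtnet().add_plugin_v2(plug_inputs, plug)

# IPluginV3 path (cp_split_plugin)
cp_split_plug = plg_creator.create_plugin("cp_split", pfc, trt.TensorRTPhase.BUILD)
layer = default_trtnet().add_plugin_v3(plug_inputs, [], cp_split_plug)  # [] = shape inputs
</pre>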

</article>

<footer class="prev-next-footer d-print-none">
<div class="prev-next-area">
</div>
</footer>

</div>

<div class="bd-sidebar-secondary"></div>

</div>
<footer class="bd-footer-content">
</footer>

</main>
</div>
</div>

<!-- Scripts loaded after <body> so the DOM is not blocked -->
<script defer src="../../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
<script defer src="../../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>

<footer class="bd-footer">
<div class="bd-footer__inner bd-page-width">

<div class="footer-items__start">

<div class="footer-item">
<a class="footer-brand logo" href="https://www.nvidia.com">
<img src="../../_static/nvidia-logo-horiz-rgb-1c-blk-for-screen.svg" class="logo__image only-light" alt="NVIDIA"/>
<img src="../../_static/nvidia-logo-horiz-rgb-1c-wht-for-screen.svg" class="logo__image only-dark" alt="NVIDIA"/>
</a></div>

<div class="footer-item">
<div class="footer-links">
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/privacy-policy/">Privacy Policy</a>
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/privacy-center/">Manage My Privacy</a>
<a class="external" href="https://www.nvidia.com/en-us/preferences/start/">Do Not Sell or Share My Data</a>
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/terms-of-service/">Terms of Service</a>
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/accessibility/">Accessibility</a>
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/company-policies/">Corporate Policies</a>
<a class="external" href="https://www.nvidia.com/en-us/product-security/">Product Security</a>
<a class="external" href="https://www.nvidia.com/en-us/contact/">Contact</a>
</div>
</div>

<div class="footer-item">
<p class="copyright">
Copyright © 2025, NVIDIA.
<br/>
</p>
</div>

<div class="footer-item">
<div class="extra_footer">
<p>Last updated on September 02, 2025.</p>
<p>This page is generated by TensorRT-LLM commit <a href="https://github.com/NVIDIA/TensorRT-LLM/tree/e81c50d">e81c50d</a>.</p>
</div></div>

</div>
</div>
</footer>
</body>
</html>