TensorRT-LLMs/_modules/tensorrt_llm/functional.html
2026-01-08 05:44:03 +00:00

8802 lines
1.1 MiB
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!DOCTYPE html>
<html lang="en" data-content_root="../../" >
<head>
<meta charset="utf-8" />
<!-- Single viewport declaration; a second, conflicting copy lower in the head was removed. -->
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>tensorrt_llm.functional &#8212; TensorRT LLM</title>
<!-- Restore persisted color mode/theme before first paint to avoid a flash of the wrong theme. -->
<script data-cfasync="false">
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
</script>
<!--
this gives us a css class that will be invisible only if js is disabled
-->
<noscript>
<style>
.pst-js-only { display: none !important; }
</style>
</noscript>
<!-- Loaded before other Sphinx assets -->
<link href="../../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link href="../../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link rel="stylesheet" type="text/css" href="../../_static/pygments.css?v=8f2a1f02" />
<link rel="stylesheet" type="text/css" href="../../_static/styles/nvidia-sphinx-theme.css?v=933278ad" />
<link rel="stylesheet" type="text/css" href="../../_static/copybutton.css?v=76b2166b" />
<link rel="stylesheet" type="text/css" href="../../_static/autodoc_pydantic.css" />
<link rel="stylesheet" type="text/css" href="../../_static/togglebutton.css?v=13237357" />
<link rel="stylesheet" type="text/css" href="../../_static/config_selector.css?v=e17d8078" />
<link rel="stylesheet" type="text/css" href="../../_static/custom.css?v=19d20f17" />
<!-- So that users can add custom icons -->
<script src="../../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
<!-- Pre-loaded scripts that we'll load fully later -->
<link rel="preload" as="script" href="../../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
<link rel="preload" as="script" href="../../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
<script src="../../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../../_static/doctools.js?v=9a2dae69"></script>
<script src="../../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../../_static/copybutton.js?v=65e89d2a"></script>
<script src="../../_static/config_selector.js?v=aaf6cd4a"></script>
<script>let toggleHintShow = 'Click to show';</script>
<script>let toggleHintHide = 'Click to hide';</script>
<script>let toggleOpenOnPrint = 'true';</script>
<script src="../../_static/togglebutton.js?v=4a39c7ea"></script>
<!-- NOTE(review): this assignment appeared twice verbatim; the redundant duplicate was removed. -->
<script>var togglebuttonSelector = '.toggle, .admonition.dropdown';</script>
<script>DOCUMENTATION_OPTIONS.pagename = '_modules/tensorrt_llm/functional';</script>
<script>
DOCUMENTATION_OPTIONS.theme_version = '0.16.1';
DOCUMENTATION_OPTIONS.theme_switcher_json_url = './_static/switcher.json';
DOCUMENTATION_OPTIONS.theme_switcher_version_match = '1.2.0rc7';
DOCUMENTATION_OPTIONS.show_version_warning_banner =
false;
</script>
<link rel="icon" href="../../_static/favicon.png"/>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<meta name="docsearch:language" content="en"/>
<meta name="docsearch:version" content="1.2.0rc7" />
</head>
<body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode="">
<div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div>
<div id="pst-scroll-pixel-helper"></div>
<button type="button" class="btn rounded-pill" id="pst-back-to-top">
<i class="fa-solid fa-arrow-up"></i>Back to top</button>
<!-- Modal search dialog; the navbar search buttons and the Ctrl+K shortcut open this form. -->
<dialog id="pst-search-dialog">
  <form class="bd-search d-flex align-items-center" action="../../search.html" method="get">
    <i class="fa-solid fa-magnifying-glass"></i>
    <!-- Browser autocomplete/autocorrect features are disabled so the query reaches Sphinx search verbatim. -->
    <input class="form-control" name="q" type="search" placeholder="Search the docs ..." aria-label="Search the docs ..." autocomplete="off" autocorrect="off" autocapitalize="off" spellcheck="false">
    <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span>
  </form>
</dialog>
<div class="pst-async-banner-revealer d-none">
<aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside>
</div>
<!-- Top navigation bar. Class names, ids, and data-bs-* attributes are hooks for the
     theme's Bootstrap JS/CSS, so the markup is left exactly as generated. -->
<header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none">
<div class="bd-header__inner bd-page-width">
<!-- Hamburger toggle for the primary sidebar (labelled "Site navigation") on narrow viewports. -->
<button class="pst-navbar-icon sidebar-toggle primary-toggle" aria-label="Site navigation">
<span class="fa-solid fa-bars"></span>
</button>
<!-- Brand/logo linking back to the docs landing page; light and dark logo variants. -->
<div class="col-lg-3 navbar-header-items__start">
<div class="navbar-item">
<a class="navbar-brand logo" href="../../index.html">
<img src="../../_static/nvidia-logo-horiz-rgb-blk-for-screen.svg" class="logo__image only-light" alt="TensorRT LLM - Home"/>
<img src="../../_static/nvidia-logo-horiz-rgb-wht-for-screen.svg" class="logo__image only-dark pst-js-only" alt="TensorRT LLM - Home"/>
<p class="title logo__title">TensorRT LLM</p>
</a></div>
</div>
<div class="col-lg-9 navbar-header-items">
<div class="me-auto navbar-header-items__center">
<div class="navbar-item">
<!-- Docs version switcher dropdown; list items are filled in by theme JS (see comment below). -->
<div class="version-switcher__container dropdown pst-js-only">
<button id="pst-version-switcher-button-2"
type="button"
class="version-switcher__button btn btn-sm dropdown-toggle"
data-bs-toggle="dropdown"
aria-haspopup="listbox"
aria-controls="pst-version-switcher-list-2"
aria-label="Version switcher list"
>
Choose version <!-- this text may get changed later by javascript -->
<span class="caret"></span>
</button>
<div id="pst-version-switcher-list-2"
class="version-switcher__menu dropdown-menu list-group-flush py-0"
role="listbox" aria-labelledby="pst-version-switcher-button-2">
<!-- dropdown will be populated by javascript on page load -->
</div>
</div></div>
</div>
<div class="navbar-header-items__end">
<!-- Search trigger button (desktop); opens the search dialog and shows the Ctrl+K hint. -->
<div class="navbar-item navbar-persistent--container">
<button class="btn search-button-field search-button__button pst-js-only" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="fa-solid fa-magnifying-glass"></i>
<span class="search-button__default-text">Search</span>
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
</button>
</div>
<div class="navbar-item">
<!-- Color-mode switch with three states: light, dark, and system ("auto"). -->
<button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button pst-js-only" aria-label="Color mode" data-bs-title="Color mode" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light" title="Light"></i>
<i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark" title="Dark"></i>
<i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto" title="System Settings"></i>
</button></div>
</div>
</div>
<!-- Duplicate of the search trigger above, shown in the mobile layout instead. -->
<div class="navbar-persistent--mobile">
<button class="btn search-button-field search-button__button pst-js-only" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="fa-solid fa-magnifying-glass"></i>
<span class="search-button__default-text">Search</span>
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
</button>
</div>
</div>
</header>
<div class="bd-container">
<div class="bd-container__inner bd-page-width">
<dialog id="pst-primary-sidebar-modal"></dialog>
<div id="pst-primary-sidebar" class="bd-sidebar-primary bd-sidebar">
<a class="navbar-brand logo" href="../../index.html">
<img src="../../_static/nvidia-logo-horiz-rgb-blk-for-screen.svg" class="logo__image only-light" alt="TensorRT LLM - Home"/>
<img src="../../_static/nvidia-logo-horiz-rgb-wht-for-screen.svg" class="logo__image only-dark pst-js-only" alt="TensorRT LLM - Home"/>
<p class="title logo__title">TensorRT LLM</p>
</a>
<div class="sidebar-header-items sidebar-primary__section">
<div class="sidebar-header-items__center">
<div class="navbar-item">
<div class="version-switcher__container dropdown pst-js-only">
<button id="pst-version-switcher-button-3"
type="button"
class="version-switcher__button btn btn-sm dropdown-toggle"
data-bs-toggle="dropdown"
aria-haspopup="listbox"
aria-controls="pst-version-switcher-list-3"
aria-label="Version switcher list"
>
Choose version <!-- this text may get changed later by javascript -->
<span class="caret"></span>
</button>
<div id="pst-version-switcher-list-3"
class="version-switcher__menu dropdown-menu list-group-flush py-0"
role="listbox" aria-labelledby="pst-version-switcher-button-3">
<!-- dropdown will be populated by javascript on page load -->
</div>
</div></div>
</div>
<div class="sidebar-header-items__end">
<div class="navbar-item">
<button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button pst-js-only" aria-label="Color mode" data-bs-title="Color mode" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light" title="Light"></i>
<i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark" title="Dark"></i>
<i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto" title="System Settings"></i>
</button></div>
</div>
</div>
<div class="sidebar-primary-items__start sidebar-primary__section">
<div class="sidebar-primary-item">
<nav class="bd-docs-nav bd-links"
aria-label="Table of Contents">
<p class="bd-links__title" role="heading" aria-level="1">Table of Contents</p>
<div class="bd-toc-item navbar-nav"><p aria-level="2" class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../overview.html">Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../quick-start-guide.html">Quick Start Guide</a></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../../installation/index.html">Installation</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../../installation/containers.html">Pre-built release container images on NGC</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../installation/linux.html">Installing on Linux via <code class="docutils literal notranslate"><span class="pre">pip</span></code></a></li>
<li class="toctree-l2"><a class="reference internal" href="../../installation/build-from-source-linux.html">Building from Source Code on Linux</a></li>
</ul>
</details></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Deployment Guide</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1 has-children"><a class="reference internal" href="../../examples/llm_api_examples.html">LLM Examples</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_inference.html">Generate text</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_inference_async.html">Generate text asynchronously</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_inference_async_streaming.html">Generate text in streaming</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_inference_distributed.html">Distributed LLM Generation</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_guided_decoding.html">Generate text with guided decoding</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_logits_processor.html">Control generated text using logits processor</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_multilora.html">Generate text with multiple LoRA adapters</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_sparse_attention.html">Sparse Attention</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_speculative_decoding.html">Speculative Decoding</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_kv_cache_connector.html">KV Cache Connector</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_kv_cache_offloading.html">KV Cache Offloading</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_runtime.html">Runtime Configuration Examples</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_sampling.html">Sampling Techniques Showcase</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_mgmn_llm_distributed.html">Run LLM-API with pytorch backend on Slurm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_mgmn_trtllm_bench.html">Run trtllm-bench with pytorch backend on Slurm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_mgmn_trtllm_serve.html">Run trtllm-serve with pytorch backend on Slurm</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../../examples/trtllm_serve_examples.html">Online Serving Examples</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../../examples/aiperf_client.html">Aiperf Client</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/aiperf_client_for_multimodal.html">Aiperf Client For Multimodal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/curl_chat_client.html">Curl Chat Client</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/curl_chat_client_for_multimodal.html">Curl Chat Client For Multimodal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/curl_completion_client.html">Curl Completion Client</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/curl_responses_client.html">Curl Responses Client</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/deepseek_r1_reasoning_parser.html">Deepseek R1 Reasoning Parser</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/openai_chat_client.html">OpenAI Chat Client</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/openai_chat_client_for_multimodal.html">OpenAI Chat Client for Multimodal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/openai_completion_client.html">OpenAI Completion Client</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/openai_completion_client_for_lora.html">Openai Completion Client For Lora</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/openai_completion_client_json_schema.html">OpenAI Completion Client with JSON Schema</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/openai_responses_client.html">OpenAI Responses Client</a></li>
</ul>
</details></li>
<li class="toctree-l1"><a class="reference internal" href="../../examples/dynamo_k8s_example.html">Dynamo K8s Example</a></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../../deployment-guide/index.html">Model Recipes</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../../deployment-guide/deployment-guide-for-deepseek-r1-on-trtllm.html">Deployment Guide for DeepSeek R1 on TensorRT LLM - Blackwell &amp; Hopper Hardware</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../deployment-guide/deployment-guide-for-llama3.3-70b-on-trtllm.html">Deployment Guide for Llama3.3 70B on TensorRT LLM - Blackwell &amp; Hopper Hardware</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../deployment-guide/deployment-guide-for-llama4-scout-on-trtllm.html">Deployment Guide for Llama4 Scout 17B on TensorRT LLM - Blackwell &amp; Hopper Hardware</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../deployment-guide/deployment-guide-for-gpt-oss-on-trtllm.html">Deployment Guide for GPT-OSS on TensorRT-LLM - Blackwell Hardware</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../deployment-guide/deployment-guide-for-qwen3-on-trtllm.html">Deployment Guide for Qwen3 on TensorRT LLM - Blackwell &amp; Hopper Hardware</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../deployment-guide/deployment-guide-for-qwen3-next-on-trtllm.html">Deployment Guide for Qwen3 Next on TensorRT LLM - Blackwell &amp; Hopper Hardware</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../deployment-guide/deployment-guide-for-kimi-k2-thinking-on-trtllm.html">Deployment Guide for Kimi K2 Thinking on TensorRT LLM - Blackwell</a></li>
</ul>
</details></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Models</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../models/supported-models.html">Supported Models</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../models/adding-new-model.html">Adding a New Model</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">CLI Reference</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../commands/trtllm-bench.html">trtllm-bench</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../commands/trtllm-eval.html">trtllm-eval</a></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../../commands/trtllm-serve/index.html">trtllm-serve</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../../commands/trtllm-serve/trtllm-serve.html">trtllm-serve</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../commands/trtllm-serve/run-benchmark-with-trtllm-serve.html">Run benchmarking with <code class="docutils literal notranslate"><span class="pre">trtllm-serve</span></code></a></li>
</ul>
</details></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">API Reference</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../llm-api/index.html">LLM API Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../llm-api/reference.html">API Reference</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Features</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../features/feature-combination-matrix.html">Feature Combination Matrix</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/attention.html">Multi-Head, Multi-Query, and Group-Query Attention</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/disagg-serving.html">Disaggregated Serving</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/kvcache.html">KV Cache System</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/long-sequence.html">Long Sequences</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/lora.html">LoRA (Low-Rank Adaptation)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/multi-modality.html">Multimodal Support in TensorRT LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/overlap-scheduler.html">Overlap Scheduler</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/paged-attention-ifb-scheduler.html">Paged Attention, IFB, and Request Scheduling</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/parallel-strategy.html">Parallelism in TensorRT LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/quantization.html">Quantization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/sampling.html">Sampling</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/additional-outputs.html">Additional Outputs</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/guided-decoding.html">Guided Decoding</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/speculative-decoding.html">Speculative Decoding</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/checkpoint-loading.html">Checkpoint Loading</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/auto_deploy/auto-deploy.html">AutoDeploy (Beta)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/ray-orchestrator.html">Ray Orchestrator (Prototype)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/torch_compile_and_piecewise_cuda_graph.html">Torch Compile &amp; Piecewise CUDA Graph</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/helix.html">Helix Parallelism</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/kv-cache-connector.html">KV Cache Connector</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Developer Guide</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../developer-guide/overview.html">Architecture Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../developer-guide/perf-analysis.html">Performance Analysis</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../developer-guide/perf-benchmarking.html">TensorRT LLM Benchmarking</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../developer-guide/ci-overview.html">Continuous Integration Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../developer-guide/dev-containers.html">Using Dev Containers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../developer-guide/api-change.html">LLM API Change Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../developer-guide/kv-transfer.html">Introduction to KV Cache Transmission</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Blogs</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog10_ADP_Balance_Strategy.html">ADP Balance Strategy</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog11_GPT_OSS_Eagle3.html">Running GPT-OSS-120B with Eagle3 Speculative Decoding on GB200/B200 (TensorRT LLM)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog12_Combining_Guided_Decoding_and_Speculative_Decoding.html">Combining Guided Decoding and Speculative Decoding: Making CPU and GPU Cooperate Seamlessly</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog13_Inference_Time_Compute_Implementation_in_TensorRT-LLM.html">Inference Time Compute Implementation in TensorRT LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog14_Scaling_Expert_Parallelism_in_TensorRT-LLM_part3.html">Scaling Expert Parallelism in TensorRT LLM (Part 3: Pushing the Performance Boundary)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog1_Pushing_Latency_Boundaries_Optimizing_DeepSeek-R1_Performance_on_NVIDIA_B200_GPUs.html">Pushing Latency Boundaries: Optimizing DeepSeek-R1 Performance on NVIDIA B200 GPUs</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog2_DeepSeek_R1_MTP_Implementation_and_Optimization.html">DeepSeek R1 MTP Implementation and Optimization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog3_Optimizing_DeepSeek_R1_Throughput_on_NVIDIA_Blackwell_GPUs.html">Optimizing DeepSeek R1 Throughput on NVIDIA Blackwell GPUs: A Deep Dive for Developers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog4_Scaling_Expert_Parallelism_in_TensorRT-LLM.html">Scaling Expert Parallelism in TensorRT LLM (Part 1: Design and Implementation of Large-scale EP)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog5_Disaggregated_Serving_in_TensorRT-LLM.html">Disaggregated Serving in TensorRT LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog6_Llama4_maverick_eagle_guide.html">How to launch Llama4 Maverick + Eagle3 TensorRT LLM server</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog7_NGram_performance_Analysis_And_Auto_Enablement.html">N-Gram Speculative Decoding in TensorRT LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog8_Scaling_Expert_Parallelism_in_TensorRT-LLM_part2.html">Scaling Expert Parallelism in TensorRT LLM (Part 2: Performance Status and Optimization)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.html">Running a High Performance GPT-OSS-120B Inference Server with TensorRT LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.html">How to get best performance on DeepSeek-R1 in TensorRT LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/XQA-kernel.html">New XQA-kernel provides 2.4x more Llama-70B throughput within the same latency budget</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Quick Links</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/releases">Releases</a></li>
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM">Github Code</a></li>
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap">Roadmap</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Use TensorRT Engine</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../legacy/tensorrt_quickstart.html">LLM API with TensorRT Engine</a></li>
</ul>
</div>
</nav></div>
</div>
<div class="sidebar-primary-items__end sidebar-primary__section">
</div>
</div>
<main id="main-content" class="bd-main" role="main">
<div class="bd-content">
<div class="bd-article-container">
<div class="bd-header-article d-print-none">
<div class="header-article-items header-article__inner">
<div class="header-article-items__start">
<div class="header-article-item">
<nav aria-label="Breadcrumb" class="d-print-none">
<ul class="bd-breadcrumbs">
<li class="breadcrumb-item breadcrumb-home">
<a href="../../index.html" class="nav-link" aria-label="Home">
<i class="fa-solid fa-home"></i>
</a>
</li>
<li class="breadcrumb-item"><a href="../index.html" class="nav-link">Module code</a></li>
<li class="breadcrumb-item active" aria-current="page"><span class="ellipsis">tensorrt_llm.functional</span></li>
</ul>
</nav>
</div>
</div>
</div>
</div>
<div id="searchbox"></div>
<article class="bd-article">
<h1>Source code for tensorrt_llm.functional</h1><div class="highlight"><pre>
<span></span><span class="c1"># SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION &amp; AFFILIATES. All rights reserved.</span>
<span class="c1"># SPDX-License-Identifier: Apache-2.0</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span>
<span class="c1"># you may not use this file except in compliance with the License.</span>
<span class="c1"># You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">math</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">weakref</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">collections</span><span class="w"> </span><span class="kn">import</span> <span class="n">OrderedDict</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">enum</span><span class="w"> </span><span class="kn">import</span> <span class="n">IntEnum</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">functools</span><span class="w"> </span><span class="kn">import</span> <span class="n">partial</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">typing</span><span class="w"> </span><span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
<span class="c1"># isort: off</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">tensorrt</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">trt</span>
<span class="c1"># isort: on</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.</span><span class="w"> </span><span class="kn">import</span> <span class="n">graph_rewriting</span> <span class="k">as</span> <span class="n">gw</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">._common</span><span class="w"> </span><span class="kn">import</span> <span class="n">default_net</span><span class="p">,</span> <span class="n">default_trtnet</span><span class="p">,</span> <span class="n">precision</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">._utils</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span><span class="n">QuantModeWrapper</span><span class="p">,</span> <span class="n">bf16_array</span><span class="p">,</span> <span class="n">bool_array</span><span class="p">,</span>
<span class="n">dim_resolve_negative</span><span class="p">,</span> <span class="n">dim_to_trt_axes</span><span class="p">,</span> <span class="n">dims_array</span><span class="p">,</span>
<span class="n">fp16_array</span><span class="p">,</span> <span class="n">fp32_array</span><span class="p">,</span> <span class="n">get_sm_version</span><span class="p">,</span> <span class="n">int32_array</span><span class="p">,</span>
<span class="n">int64_array</span><span class="p">,</span> <span class="n">np_dtype_to_trt</span><span class="p">,</span> <span class="n">str_dtype_to_trt</span><span class="p">,</span>
<span class="n">trt_dtype_to_np</span><span class="p">,</span> <span class="n">trt_dtype_to_str</span><span class="p">)</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.network</span><span class="w"> </span><span class="kn">import</span> <span class="n">PluginInfo</span><span class="p">,</span> <span class="n">set_np_weight</span><span class="p">,</span> <span class="n">set_plugin_info</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.plugin</span><span class="w"> </span><span class="kn">import</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">,</span> <span class="n">current_all_reduce_helper</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.quantization</span><span class="w"> </span><span class="kn">import</span> <span class="n">QuantMode</span>
<div class="viewcode-block" id="DimRange">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.DimRange">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">DimRange</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> One DimRange object stores the ranges of all the dimensions of one tensor in one optimization profile.</span>
<span class="sd"> For example, tensor has 2 dimensions. Then the data members are:</span>
<span class="sd"> self.min = [dim 0 min, dim 1 min]</span>
<span class="sd"> self.opt = [dim 0 opt, dim 1 opt]</span>
<span class="sd"> self.max = [dim 0 max, dim 1 max]</span>
<span class="sd"> For static dimension, it has min==opt==max, thus the shape param in the ctor can be an integer</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]],</span>
<span class="n">names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Parameters:</span>
<span class="sd"> shape: a list with length N, each element is an integer or a 3-elements tuple/list of int,</span>
<span class="sd"> where N is the number of dimensions for a tensor.</span>
<span class="sd"> When one element is an integer, it means that dimension is static.</span>
<span class="sd"> Otherwise, when one element is a tuple/list, it means the dimension is dynamic.</span>
<span class="sd"> The 3 elements in one tuple/list is ordered by (min, opt, max), and this function asserts</span>
<span class="sd"> 0 &lt;= min &lt;= opt &lt;= max.</span>
<span class="sd"> Example, for a 3 rank tensor, with 1st dimension being static and has value 16, and second dimension being dynamic with</span>
<span class="sd"> min/opt/max values being 1/8/32, and 3rd dimension being static and has value 8.</span>
<span class="sd"> The shape parameter could be:</span>
<span class="sd"> [16, (1, 8, 32), 8]</span>
<span class="sd"> It has same semantics of</span>
<span class="sd"> [(16, 16, 16), (1, 8, 32), (8, 8, 8)]</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min</span> <span class="o">=</span> <span class="p">[]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">opt</span> <span class="o">=</span> <span class="p">[]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max</span> <span class="o">=</span> <span class="p">[]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dimension_names</span> <span class="o">=</span> <span class="n">names</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">names</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span>
<span class="n">shape</span>
<span class="p">),</span> <span class="s2">&quot;Expecting shape list and name list must have same length, got {shape=}, {name=}&quot;</span>
<span class="k">for</span> <span class="n">dim</span> <span class="ow">in</span> <span class="n">shape</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span> <span class="ow">and</span> <span class="mi">0</span> <span class="o">&lt;=</span> <span class="n">dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="n">dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="n">dim</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> \
<span class="s2">&quot;Each dimension must specify a 3-elements tuple or list in the order of (min,opt,max), got {dim=}&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">opt</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">min</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">opt</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">AttributeError</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;Dimension should be [min, opt, max] (dynamic shape) or int (specific value). Got </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__eq__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">__value</span><span class="p">:</span> <span class="nb">object</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">return</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">__value</span><span class="p">,</span> <span class="n">DimRange</span><span class="p">)</span> <span class="ow">and</span> \
<span class="bp">self</span><span class="o">.</span><span class="n">dimension_names</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">dimension_names</span> <span class="ow">and</span> \
<span class="bp">self</span><span class="o">.</span><span class="n">min</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">min</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">opt</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">opt</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">max</span> <span class="o">==</span> <span class="n">__value</span><span class="o">.</span><span class="n">max</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="k">return</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__str__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="k">return</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">dimension_names</span><span class="si">=}</span><span class="s2"> </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">min</span><span class="si">=}</span><span class="s2">, </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">opt</span><span class="si">=}</span><span class="s2">, </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">max</span><span class="si">=}</span><span class="s2">)&quot;</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__hash__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="k">return</span> <span class="nb">hash</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="p">))</span></div>
<div class="viewcode-block" id="Tensor">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">Tensor</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> The class to represent dense tensors.</span>
<span class="sd"> A dense tensor is named, has a shape and contains typed elements. Each</span>
<span class="sd"> dimension of a tensor can either be static or dynamic. Static dimensions</span>
<span class="sd"> are known at engine compilation by TensorRT. Dynamic dimensions can take</span>
<span class="sd"> values determined at runtime. The tensor can be located on the host (CPU)</span>
<span class="sd"> or the device (GPU).</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">shape</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">dim_range</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">is_network_input</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">location</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">TensorLocation</span><span class="o">.</span><span class="n">DEVICE</span><span class="p">,</span>
<span class="n">network</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">trt_tensor</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Parameters:</span>
<span class="sd"> name : str</span>
<span class="sd"> The name of the tensor.</span>
<span class="sd"> dtype : tensorrt.DataType</span>
<span class="sd"> The type of the elements of the tensor. See the TensorRT</span>
<span class="sd"> documentation for list of supported data types.</span>
<span class="sd"> shape : tensorrt.Dims</span>
<span class="sd"> The dimensions of the tensor. In TensorRT-LLM, tensors can have</span>
<span class="sd"> static or dynamic dimensions (it is possible to mix static and</span>
<span class="sd"> dynamic dimensions). A static dimension is known when the</span>
<span class="sd"> TensorRT engine is built. A dynamic dimension can be set when</span>
<span class="sd"> the engine is executed. Use -1 for dynamic dimensions.</span>
<span class="sd"> dim_range : OrderedDict</span>
<span class="sd"> An ordered dictionary (the positions of the elements matter)</span>
<span class="sd"> that associates a name and a range of values to the dimensions.</span>
<span class="sd"> For a static dimension, the range must be limited to a single</span>
<span class="sd"> value. For a dynamic dimension, the range is defined by three</span>
<span class="sd"> values [min, opt, max] where min and max are, respectively, the</span>
<span class="sd"> smallest and largest possible values of that dimension. The</span>
<span class="sd"> opt value is used by TensorRT to optimize the engine for the</span>
<span class="sd"> most common case.</span>
<span class="sd"> Assume there is N optimization profiles, each item dim_range dict is ordered by:</span>
<span class="sd"> (dynamic dimension name : [profile 0 (min, opt, max), profile 1 (min, opt, max), ... profile N(min, opt, max)])</span>
<span class="sd"> or it&#39;s following when the dimension is static (can think as min==opt==max):</span>
<span class="sd"> (static dimension name : [profile 0 value, profile 1 value, ... profile N value])</span>
<span class="sd"> For static dimension the profile 0-N value must be same, (TODO: can it be simplified to be only 1 value?)</span>
<span class="sd"> And number of keys is equal to number of optimization profiles.</span>
<span class="sd"> is_network_input : bool</span>
<span class="sd"> A boolean indicating if that tensor is an input of the network.</span>
<span class="sd"> Inputs must be provided by the user to run the engine.</span>
<span class="sd"> location : tensorrt.TensorLocation</span>
<span class="sd"> A flag to indicate where the tensor will be located. It can be</span>
<span class="sd"> on the host (CPU) or the device (GPU).</span>
<span class="sd"> network: Network</span>
<span class="sd"> A parent Network instance, that helps to fine the users of this tensor.</span>
<span class="sd"> trt_tensor: trt.ITensor</span>
<span class="sd"> Construct with the ITensor instance directly, and no shape profiles are required.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="c1"># Layout of self.profiles</span>
<span class="c1"># Opt profile 0: dim 0 (min, opt, max), dim 1 (min, opt, max) ... dim M</span>
<span class="c1"># Opt profile 1: dim 0 (min, opt, max), dim 1 (min, opt, max) ... dim M</span>
<span class="c1"># ...</span>
<span class="c1"># Opt profile N: dim 0 ... dim M</span>
<span class="c1"># So from the dim_range arg to self.profiles conversion, there is a layout transpose</span>
<span class="c1"># dim_range arg is: {M dimension x N profiles}, while self.profiles layout is {N profiles x M dimensions}</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">profiles</span> <span class="o">=</span> <span class="p">[]</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_tensor_wrapper</span> <span class="o">=</span> <span class="kc">False</span> <span class="c1"># specially for the graph rewriter</span>
<span class="c1"># work as a wrapper for a trt.ITensor, this is used specially in the graph rewriter</span>
<span class="k">if</span> <span class="n">trt_tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_tensor_wrapper</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">assert</span> <span class="n">network</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="o">=</span> <span class="n">trt_tensor</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_network</span> <span class="o">=</span> <span class="n">weakref</span><span class="o">.</span><span class="n">ref</span><span class="p">(</span><span class="n">network</span><span class="p">)</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="n">is_network_input</span><span class="p">,</span> <span class="s2">&quot;is_network_input should be False when trt_tensor is not None&quot;</span>
<span class="k">return</span>
<span class="c1"># be cautious here, the weakref is critical to avoid circular referencing before Network and Tensor</span>
<span class="c1"># using strong reference will likely cause significant peak memory increase, since Network objects</span>
<span class="c1"># holds the weights data.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_network</span> <span class="o">=</span> <span class="n">weakref</span><span class="o">.</span><span class="n">ref</span><span class="p">(</span><span class="n">default_net</span><span class="p">())</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_network_input</span> <span class="o">=</span> <span class="n">is_network_input</span>
<span class="k">if</span> <span class="n">is_network_input</span><span class="p">:</span>
<span class="k">if</span> <span class="n">dim_range</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim_range</span><span class="p">,</span> <span class="n">OrderedDict</span><span class="p">)</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span>
<span class="n">dim_range</span>
<span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">1</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;Each input tensor shall have at least one dimension, tensor &#39;</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">&#39; found </span><span class="si">{</span><span class="n">dim_range</span><span class="si">=}</span><span class="s2">&quot;</span>
<span class="n">found_profiles</span> <span class="o">=</span> <span class="p">[</span>
<span class="nb">len</span><span class="p">(</span><span class="n">ranges</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">ranges</span> <span class="ow">in</span> <span class="n">dim_range</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
<span class="p">]</span>
<span class="k">assert</span> <span class="nb">all</span><span class="p">(</span>
<span class="p">[</span><span class="n">x</span> <span class="o">==</span> <span class="n">found_profiles</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">found_profiles</span><span class="p">]</span>
<span class="p">),</span> <span class="sa">f</span><span class="s2">&quot;Expecting all the dimensions in the dim_range has same number of profiles, tensor &#39;</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">&#39; got </span><span class="si">{</span><span class="n">dim_range</span><span class="si">=}</span><span class="s2">&quot;</span>
<span class="n">num_opt_profile</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">dim_range</span><span class="o">.</span><span class="n">items</span><span class="p">())[</span><span class="mi">0</span><span class="p">][</span><span class="mi">1</span><span class="p">])</span>
<span class="k">assert</span> <span class="n">num_opt_profile</span> <span class="o">&gt;=</span> <span class="mi">1</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_opt_profile</span><span class="p">):</span>
<span class="n">range_shape</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">dimension_names</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">dim</span><span class="p">,</span> <span class="n">ranges</span> <span class="ow">in</span> <span class="n">dim_range</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">ranges</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">))</span>
<span class="n">range_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">ranges</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
<span class="n">dimension_names</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">profiles</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">DimRange</span><span class="p">(</span><span class="n">range_shape</span><span class="p">,</span> <span class="n">dimension_names</span><span class="p">))</span>
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">_add_input</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">dtype</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">dim_range</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
<span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">shape</span>
<span class="bp">self</span><span class="o">.</span><span class="n">location</span> <span class="o">=</span> <span class="n">location</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">network</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_network</span><span class="p">()</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">name</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> The name of the tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span>
<span class="nd">@name</span><span class="o">.</span><span class="n">setter</span>
<span class="k">def</span><span class="w"> </span><span class="nf">name</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Set the name of the tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">name</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> The type of the elements in the tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">dtype</span>
<span class="nd">@dtype</span><span class="o">.</span><span class="n">setter</span>
<span class="k">def</span><span class="w"> </span><span class="nf">dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Set the type of the elements in the tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">dtype</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> The shape of the tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
<span class="nd">@shape</span><span class="o">.</span><span class="n">setter</span>
<span class="k">def</span><span class="w"> </span><span class="nf">shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Set the shape of the tensor. See __init__.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">shape</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">shape</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">location</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> The physical location of the tensor (on the host or the device).</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">location</span>
<span class="nd">@location</span><span class="o">.</span><span class="n">setter</span>
<span class="k">def</span><span class="w"> </span><span class="nf">location</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">location</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Set the physical location of the tensor (on the host or the device). See __init__.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">location</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">location</span> <span class="o">=</span> <span class="n">location</span>
<div class="viewcode-block" id="Tensor.mark_output">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.mark_output">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">mark_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">name</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">dtype</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Mark a tensor as a network output.</span>
<span class="sd"> When a tensor is marked as an output, its content can be obtained after</span>
<span class="sd"> the execution of the TensorRT engine. The user is responsible for</span>
<span class="sd"> allocating buffers to store the output tensors when preparing the</span>
<span class="sd"> execution of the TensorRT engine.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">name</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">)</span>
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">_mark_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">dtype</span><span class="p">)</span></div>
<span class="k">def</span><span class="w"> </span><span class="fm">__add__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.add.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">add</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__radd__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.add.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">add</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__sub__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.sub.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">sub</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__rsub__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.sub.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">sub</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__mul__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.mul.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">mul</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__rmul__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.mul.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">mul</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__truediv__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.div.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">div</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__floordiv__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.floordiv.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">floordiv</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__mod__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.modulo.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">modulo</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__lt__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.lt.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">lt</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__gt__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.gt.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">gt</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__eq__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.eq.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_tensor_wrapper</span><span class="p">:</span>
<span class="c1"># for graph rewriter</span>
<span class="k">return</span> <span class="nb">hash</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">==</span> <span class="nb">hash</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># for creating the network</span>
<span class="k">return</span> <span class="n">eq</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__ge__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Maps to functional.gt or functional.eq.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">op_or</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="fm">__gt__</span><span class="p">(</span><span class="n">b</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="fm">__eq__</span><span class="p">(</span><span class="n">b</span><span class="p">))</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__le__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Maps to functional.lt or functional.eq.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">op_or</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="fm">__lt__</span><span class="p">(</span><span class="n">b</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="fm">__eq__</span><span class="p">(</span><span class="n">b</span><span class="p">))</span>
<div class="viewcode-block" id="Tensor.view">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.view">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">view</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.view.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">view</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.flatten">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.flatten">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">flatten</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">start_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">end_dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.flatten.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">flatten</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">start_dim</span><span class="p">,</span> <span class="n">end_dim</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.permute">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.permute">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">permute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.permute.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">permute</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.transpose">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.transpose">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">transpose</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim0</span><span class="p">,</span> <span class="n">dim1</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.transpose.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">transpose</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim0</span><span class="p">,</span> <span class="n">dim1</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.mean">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.mean">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">mean</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.mean.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">mean</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.max">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.max">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">max</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.max.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="nb">max</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.abs">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.abs">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">abs</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.abs.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="nb">abs</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.sqrt">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.sqrt">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.sqrt.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.squeeze">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.squeeze">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">squeeze</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.squeeze.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">squeeze</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.unsqueeze">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.unsqueeze">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">unsqueeze</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.unsqueeze.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">unsqueeze</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.log">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.log">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">log</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.log.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">log</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.cast">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.cast">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">cast</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.cast.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">cast</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.size">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.size">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">size</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Returns the shape of the tensor if the dim parameter is None.</span>
<span class="sd"> Otherwise, returns a size of the dimension indicated by dim. The</span>
<span class="sd"> behavior is undefined if dim is negative or exceeds the rank of the</span>
<span class="sd"> tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span></div>
<div class="viewcode-block" id="Tensor.rank">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.rank">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">rank</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Returns the rank (i.e. the number of dimensions) of the tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.ndim">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.ndim">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">ndim</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Returns the rank (i.e. the number of dimensions) of the tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span></div>
<div class="viewcode-block" id="Tensor.split">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.split">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">split</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">split_size_or_sections</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.split.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">split</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">split_size_or_sections</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.select">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.select">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">select</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.select.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">select</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">index</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.unbind">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.unbind">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">unbind</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.unbind.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">unbind</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.repeat">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.repeat">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">repeat</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sizes</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> See functional.repeat</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">repeat</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sizes</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.is_dynamic">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.is_dynamic">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">is_dynamic</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> If the argument &#39;dim&#39; is None, that function returns a boolean that</span>
<span class="sd"> indicates if the tensor contains a dynamic dimension (True) or not</span>
<span class="sd"> (False). In that case, the first dimension is excluded (as it usually</span>
<span class="sd"> corresponds to the batch size). If the argument is an integer, that</span>
<span class="sd"> functions returns a boolean that indicates if the dimension &#39;dim&#39; is</span>
<span class="sd"> dynamic (True) or not (False).</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">):</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">s</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">True</span>
<span class="k">return</span> <span class="kc">False</span></div>
<span class="c1"># graph writer related functions</span>
<div class="viewcode-block" id="Tensor.get_parent">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.get_parent">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">get_parent</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39; Get the layer that produces this tensor. &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">get_tensor_parent</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.get_users">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.get_users">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">get_users</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39; Get the layers that use this tensor as an input. &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">get_tensor_users</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.replace_all_uses_with">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.replace_all_uses_with">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">replace_all_uses_with</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_tensor</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Replace all uses of this tensor as an input to consumer layers</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">network</span><span class="o">.</span><span class="n">is_graph_altered</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">users</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_users</span><span class="p">()</span>
<span class="k">for</span> <span class="n">user</span> <span class="ow">in</span> <span class="n">users</span><span class="p">:</span>
<span class="n">inputs_changed</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">user</span><span class="o">.</span><span class="n">num_inputs</span><span class="p">):</span>
<span class="k">if</span> <span class="n">user</span><span class="o">.</span><span class="n">get_inputs</span><span class="p">(</span><span class="n">i</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="ow">is</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">:</span>
<span class="n">inputs_changed</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="n">user</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">new_tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">inputs_changed</span> <span class="o">&gt;=</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;Tensor not found in layer inputs&quot;</span>
<span class="c1"># update the FLayerMetadata as well</span>
<span class="n">flayer</span> <span class="o">=</span> <span class="n">gw</span><span class="o">.</span><span class="n">FLayerInfoMemo</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">user</span><span class="o">.</span><span class="n">name</span><span class="p">)</span>
<span class="n">flayer</span> <span class="ow">and</span> <span class="n">flayer</span><span class="o">.</span><span class="n">replace_input_with</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">new_tensor</span><span class="p">)</span></div>
<div class="viewcode-block" id="Tensor.is_trt_wrapper">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Tensor.is_trt_wrapper">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">is_trt_wrapper</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Check if there is a trt.ITensor member inside, which is required for</span>
<span class="sd"> graph rewriter. In order to differentiate usages, it may be necessary</span>
<span class="sd"> to have an inheritance hierarchy.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s1">&#39;trt_tensor&#39;</span><span class="p">):</span>
<span class="k">return</span> <span class="kc">True</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">False</span></div>
<span class="k">def</span><span class="w"> </span><span class="fm">__hash__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_trt_wrapper</span><span class="p">():</span>
<span class="k">return</span> <span class="nb">id</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="nb">id</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="sa">f</span><span class="s2">&quot;TensorRT LLM Tensor: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="si">=}</span><span class="s2"> </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="si">=}</span><span class="s2"> </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="si">=}</span><span class="s2">&quot;</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__xor__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Maps to functional.gt or functional.eq.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;self.shape: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">, b.shape: </span><span class="si">{</span><span class="n">b</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;a.shape: </span><span class="si">{</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">, b.shape: </span><span class="si">{</span><span class="n">b</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">op_xor</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span></div>
<span class="k">def</span><span class="w"> </span><span class="nf">_create_tensor</span><span class="p">(</span><span class="n">trt_tensor</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ITensor</span><span class="p">,</span> <span class="n">producer</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ILayer</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> A helper function to create a TensorRT LLM Tensor object that encapsulates</span>
<span class="sd"> the connection between the TensorRT tensor (trt.ITensor) and the layer</span>
<span class="sd"> (trt.ILayer) that produces it.</span>
<span class="sd"> That function is expected to be used as:</span>
<span class="sd"> # Insert a new layer in the network using the TensorRT API:</span>
<span class="sd"> layer = default_trtnet().add_&lt;some_layer&gt;(...)</span>
<span class="sd"> # Extract the first output of that layer and connect it to the layer.</span>
<span class="sd"> return _create_tensor(layer.get_output(0), layer)</span>
<span class="sd"> That function also sets the precision of the layer/producer to the default</span>
<span class="sd"> precision of the network.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> trt_tensor : trt.ITensor</span>
<span class="sd"> The TensorRT tensor to connect to its producer (the layer).</span>
<span class="sd"> producer : trt.ILayer</span>
<span class="sd"> The producer.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The TensorRT LLM tensor (functional.Tensor) that encapsulates the</span>
<span class="sd"> TensorRT tensor and the layer that produces it. The former is</span>
<span class="sd"> accessible through the attribute &#39;trt_tensor&#39; and the latter using the</span>
<span class="sd"> attribute &#39;producer&#39;.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="n">trt_tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="c1"># Set the layer name since this is the only</span>
<span class="c1"># centralized location to pass the name from</span>
<span class="c1"># module space to the TRT IR</span>
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">_set_layer_name</span><span class="p">(</span><span class="n">producer</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="o">.</span><span class="fm">__len__</span><span class="p">(</span>
<span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;tensor </span><span class="si">{</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span><span class="si">}</span><span class="s2"> has an invalid shape&quot;</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">name</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">shape</span><span class="o">=</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
<span class="n">is_network_input</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">tensor</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="o">=</span> <span class="n">trt_tensor</span>
<span class="n">tensor</span><span class="o">.</span><span class="n">producer</span> <span class="o">=</span> <span class="n">producer</span>
<span class="c1"># tb.print_stack(limit=10) # FOR DEBUGGING: filter producer.name if needed</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">dtype</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
<span class="k">if</span> <span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span>
<span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">SHAPE</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">GATHER</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONCATENATION</span>
<span class="p">]:</span>
<span class="n">producer</span><span class="o">.</span><span class="n">precision</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">dtype</span>
<span class="k">assert</span> <span class="n">tensor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">gw</span><span class="o">.</span><span class="n">FLayerInfoMemo</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">cur_flayer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">gw</span><span class="o">.</span><span class="n">FLayerInfoMemo</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">cur_flayer</span><span class="o">.</span><span class="n">layer_name</span> <span class="o">=</span> <span class="n">producer</span><span class="o">.</span><span class="n">name</span>
<span class="k">return</span> <span class="n">tensor</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plugin_creator</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">IPluginCreator</span><span class="p">,</span>
<span class="n">plugin_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">pfc</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plugin_info</span> <span class="o">=</span> <span class="n">PluginInfo</span><span class="p">(</span><span class="n">plugin_creator</span><span class="p">,</span> <span class="n">plugin_name</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">set_plugin_info</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">trt_network</span><span class="p">,</span> <span class="n">layer</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">plugin_info</span><span class="p">)</span>
<div class="viewcode-block" id="RotaryScalingType">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RotaryScalingType">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">RotaryScalingType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="n">none</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">linear</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">dynamic</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">longrope</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">llama3</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">yarn</span> <span class="o">=</span> <span class="mi">5</span>
<span class="n">mrope</span> <span class="o">=</span> <span class="mi">6</span>
<div class="viewcode-block" id="RotaryScalingType.from_string">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RotaryScalingType.from_string">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">from_string</span><span class="p">(</span><span class="n">s</span><span class="p">):</span>
<span class="k">try</span><span class="p">:</span>
<span class="k">return</span> <span class="n">RotaryScalingType</span><span class="p">[</span><span class="n">s</span><span class="p">]</span>
<span class="k">except</span> <span class="ne">KeyError</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Unsupported rotary scaling type: </span><span class="si">{</span><span class="n">s</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span></div>
</div>
<div class="viewcode-block" id="PositionEmbeddingType">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">PositionEmbeddingType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="n">learned_absolute</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">rope_gptj</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">rope_gpt_neox</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">long_rope</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">alibi</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">alibi_with_scale</span> <span class="o">=</span> <span class="mi">5</span>
<span class="n">relative</span> <span class="o">=</span> <span class="mi">6</span>
<span class="n">chatglm</span> <span class="o">=</span> <span class="mi">7</span>
<span class="n">yarn</span> <span class="o">=</span> <span class="mi">8</span>
<span class="n">mrope</span> <span class="o">=</span> <span class="mi">9</span>
<span class="n">deferred</span> <span class="o">=</span> <span class="mi">10</span> <span class="c1"># Apply customized positional embedding by using an external embedder. K will be cached before embedding.</span>
<div class="viewcode-block" id="PositionEmbeddingType.is_rope">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.is_rope">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">is_rope</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span> <span class="ow">in</span> <span class="p">[</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rope_gptj</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">long_rope</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mrope</span>
<span class="p">]</span></div>
<div class="viewcode-block" id="PositionEmbeddingType.is_mrope">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.is_mrope">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">is_mrope</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span> <span class="ow">in</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">mrope</span><span class="p">]</span></div>
<div class="viewcode-block" id="PositionEmbeddingType.is_alibi">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.is_alibi">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">is_alibi</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span> <span class="ow">in</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">alibi</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">alibi_with_scale</span><span class="p">]</span></div>
<div class="viewcode-block" id="PositionEmbeddingType.is_deferred">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.is_deferred">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">is_deferred</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span> <span class="ow">in</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">deferred</span><span class="p">]</span></div>
<div class="viewcode-block" id="PositionEmbeddingType.choices">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.choices">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">choices</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span>
<span class="k">return</span> <span class="p">[</span><span class="n">embedding</span><span class="o">.</span><span class="n">name</span> <span class="k">for</span> <span class="n">embedding</span> <span class="ow">in</span> <span class="n">PositionEmbeddingType</span><span class="p">]</span></div>
<span class="k">def</span><span class="w"> </span><span class="fm">__str__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span>
<div class="viewcode-block" id="PositionEmbeddingType.from_string">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.PositionEmbeddingType.from_string">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">from_string</span><span class="p">(</span><span class="n">s</span><span class="p">):</span>
<span class="k">try</span><span class="p">:</span>
<span class="k">return</span> <span class="n">PositionEmbeddingType</span><span class="p">[</span><span class="n">s</span><span class="p">]</span>
<span class="k">except</span> <span class="ne">KeyError</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Unsupported position embedding type: </span><span class="si">{</span><span class="n">s</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span></div>
</div>
<div class="viewcode-block" id="AttentionMaskType">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AttentionMaskType">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">AttentionMaskType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="n">padding</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">causal</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">sliding_window_causal</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">bidirectional</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">bidirectionalglm</span> <span class="o">=</span> <span class="mi">4</span> <span class="c1"># TODO: merge this mask into bidirectional</span>
<span class="n">blocksparse</span> <span class="o">=</span> <span class="mi">5</span>
<span class="n">custom_mask</span> <span class="o">=</span> <span class="mi">6</span></div>
<div class="viewcode-block" id="LayerNormType">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.LayerNormType">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">LayerNormType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="n">LayerNorm</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">RmsNorm</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">GroupNorm</span> <span class="o">=</span> <span class="mi">2</span></div>
<div class="viewcode-block" id="LayerNormPositionType">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.LayerNormPositionType">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">LayerNormPositionType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="n">pre_layernorm</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">post_layernorm</span> <span class="o">=</span> <span class="mi">1</span></div>
<div class="viewcode-block" id="MLPType">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.MLPType">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">MLPType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="n">MLP</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">GatedMLP</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">FusedGatedMLP</span> <span class="o">=</span> <span class="mi">2</span></div>
<div class="viewcode-block" id="SliceInputType">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.SliceInputType">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">SliceInputType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="n">data</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">start</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">size</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">stride</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">fill_value</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">axes</span> <span class="o">=</span> <span class="mi">5</span></div>
<div class="viewcode-block" id="activation">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.activation">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">activation</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">act_type</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an activation function.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the activation function is applied.</span>
<span class="sd"> act_type : trt.ActivationType</span>
<span class="sd"> The type of the activation (RELU, TANH, SIGMOID, ...).</span>
<span class="sd"> The following closures are defined in functional.*:</span>
<span class="sd"> relu for op=trt.ActivationType.RELU</span>
<span class="sd"> tanh for op=trt.ActivationType.TANH</span>
<span class="sd"> sigmoid for op=trt.ActivationType.SIGMOID</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the activation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_activation</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">act_type</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="int_clip">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.int_clip">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">int_clip</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">lower</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">upper</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">lower</span> <span class="o">&lt;=</span> <span class="n">upper</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;Lower bound must be less than or equal to upper bound i.e. </span><span class="si">{</span><span class="n">lower</span><span class="si">}</span><span class="s2"> &lt;= </span><span class="si">{</span><span class="n">upper</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">res</span> <span class="o">=</span> <span class="n">minimum</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">upper</span><span class="p">)</span>
<span class="n">res</span> <span class="o">=</span> <span class="n">maximum</span><span class="p">(</span><span class="n">res</span><span class="p">,</span> <span class="n">lower</span><span class="p">)</span>
<span class="k">return</span> <span class="n">res</span></div>
<div class="viewcode-block" id="clip">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.clip">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">clip</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a CLIP operation that sets the range to [alpha, beta].</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the activation function is applied.</span>
<span class="sd"> alpha : float</span>
<span class="sd"> The lower bound of the CLIP function.</span>
<span class="sd"> beta : float</span>
<span class="sd"> The upper bound of the CLIP function.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the activation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_activation</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">CLIP</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">alpha</span>
<span class="n">layer</span><span class="o">.</span><span class="n">beta</span> <span class="o">=</span> <span class="n">beta</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<span class="n">relu</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">activation</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">RELU</span><span class="p">)</span>
<span class="n">tanh</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">activation</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">TANH</span><span class="p">)</span>
<span class="n">sigmoid</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">activation</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">SIGMOID</span><span class="p">)</span>
<div class="viewcode-block" id="silu">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.silu">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">silu</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a SiLU (`x * sigmoid(x)`) operation.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the activation function is applied.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the activation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="nb">input</span> <span class="o">*</span> <span class="n">sigmoid</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span></div>
<div class="viewcode-block" id="swiglu">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.swiglu">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">swiglu</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a SwiGLU (`x * SiLU(gate)`) operation.</span>
<span class="sd"> That function takes a tensor, splits it into two halves along the last</span>
<span class="sd"> dimension, applies SiLU to the second half and multiplies the results. The</span>
<span class="sd"> behavior is undefined if the last dimension is not even.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the activation function is applied.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the activation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">x</span><span class="p">,</span> <span class="n">gate</span> <span class="o">=</span> <span class="n">chunk</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">silu</span><span class="p">(</span><span class="n">gate</span><span class="p">)</span> <span class="o">*</span> <span class="n">x</span></div>
<div class="viewcode-block" id="squared_relu">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.squared_relu">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">squared_relu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a Squared ReLU operation.</span>
<span class="sd"> This function applies ReLU and squares the output.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> x : Tensor</span>
<span class="sd"> The input tensor on which the activation function is applied.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the activation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="nb">pow</span><span class="p">(</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="mf">2.0</span><span class="p">)</span></div>
<div class="viewcode-block" id="cast">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.cast">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">cast</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dtype</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a cast operation.</span>
<span class="sd"> For an input tensor of type INT8, this function sets the dynamic range of</span>
<span class="sd"> the input to [-127, 127] for automatic dequantization. For a cast into</span>
<span class="sd"> INT8, that function sets the dynamic range of the output to [-127, 127] for</span>
<span class="sd"> automatic quantization.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the cast is applied.</span>
<span class="sd"> dtype : str or trt.DataType</span>
<span class="sd"> The data type of the output tensor after the cast. When &#39;dtype&#39; is</span>
<span class="sd"> provided as a string, it must be a name amongst the valid names.</span>
<span class="sd"> See _str_to_trt_dtype_dict in _utils.py for a list of supported</span>
<span class="sd"> types and type names.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the inserted layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
<span class="n">cvt_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">):</span>
<span class="n">cvt_dtype</span> <span class="o">=</span> <span class="n">dtype</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">%s</span><span class="s2"> is not supported&quot;</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">dtype</span><span class="p">))</span>
<span class="k">if</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">cvt_dtype</span><span class="p">:</span>
<span class="c1"># If input type and cast dtype are the same, do nothing</span>
<span class="k">return</span> <span class="nb">input</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_cast</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">cvt_dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_output_type</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">cvt_dtype</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="s1">&#39;int8&#39;</span><span class="p">):</span>
<span class="n">layer</span><span class="o">.</span><span class="n">get_input</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
<span class="k">if</span> <span class="n">cvt_dtype</span> <span class="o">==</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="s1">&#39;int8&#39;</span><span class="p">):</span>
<span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="flip">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.flip">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">flip</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Reverses the order of an n-D tensor along given axis in dims.</span>
<span class="sd"> That flip operation maps to a TensorRT ISliceLayer. For the dimensions</span>
<span class="sd"> listed in dims it copies the elements from the last one to the first one</span>
<span class="sd"> (from (N-1) down to 0 with a step of -1). For the dimensions not in &#39;dims&#39;,</span>
<span class="sd"> it copies the elements from the first one to the last one (from 0 to N-1</span>
<span class="sd"> with a step of 1).</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the cast is applied.</span>
<span class="sd"> dims : list or tuple</span>
<span class="sd"> The axes to flip. Negative indices are supported.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the inserted layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">dims</span><span class="p">):</span>
<span class="k">assert</span> <span class="o">-</span><span class="n">ndim</span> <span class="o">&lt;=</span> <span class="n">value</span> <span class="o">&lt;</span> <span class="n">ndim</span>
<span class="k">if</span> <span class="o">-</span><span class="n">ndim</span> <span class="o">&lt;=</span> <span class="n">value</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">dims</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">+=</span> <span class="n">ndim</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">dims</span><span class="p">))</span>
<span class="n">start_values</span> <span class="o">=</span> <span class="p">[</span>
<span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dims</span> <span class="k">else</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">stride_values</span> <span class="o">=</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dims</span> <span class="k">else</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_slice</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">start</span><span class="o">=</span><span class="n">start_values</span><span class="p">,</span>
<span class="n">shape</span><span class="o">=</span><span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(),</span>
<span class="n">stride</span><span class="o">=</span><span class="n">stride_values</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="interpolate">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.interpolate">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">interpolate</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">size</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">scale_factor</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;nearest&#39;</span><span class="p">,</span>
<span class="n">align_corners</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">recompute_scale_factor</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">antialias</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="c1">##</span>
<span class="c1">## TODO: Document that function!</span>
<span class="c1">##</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
<span class="n">input_ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">assert</span> <span class="mi">2</span> <span class="o">&lt;</span> <span class="n">input_ndim</span> <span class="o">&lt;</span> <span class="mi">6</span><span class="p">,</span> <span class="s2">&quot;Only 3D, 4D and 5D input Tensors supported&quot;</span>
<span class="k">assert</span> <span class="p">(</span><span class="n">size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="o">^</span> <span class="p">(</span>
<span class="n">scale_factor</span>
<span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">),</span> <span class="s2">&quot;Only one of out_shape or scales should be defined&quot;</span>
<span class="k">assert</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">(</span><span class="s1">&#39;nearest&#39;</span><span class="p">,</span> <span class="s1">&#39;linear&#39;</span><span class="p">,</span> <span class="s1">&#39;bilinear&#39;</span><span class="p">,</span> <span class="s1">&#39;bicubic&#39;</span><span class="p">,</span> <span class="s1">&#39;trilinear&#39;</span><span class="p">,</span>
<span class="s1">&#39;nearest-exact&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;trilinear&#39;</span> <span class="ow">and</span> <span class="n">input_ndim</span> <span class="o">!=</span> <span class="mi">5</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;trilinear only supports 5D tensor&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;bilinear&quot;</span> <span class="ow">and</span> <span class="n">input_ndim</span> <span class="o">!=</span> <span class="mi">4</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;bilinear only supports 4D tensor&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;linear&quot;</span> <span class="ow">and</span> <span class="n">input_ndim</span> <span class="o">!=</span> <span class="mi">3</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;linear only supports 3D tensor&quot;</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_resize</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">input_shape</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()</span>
<span class="n">updated_shape</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">if</span> <span class="n">scale_factor</span><span class="p">:</span>
<span class="n">scale_len</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scale_factor</span><span class="p">,</span>
<span class="p">(</span><span class="nb">float</span><span class="p">,</span> <span class="nb">int</span><span class="p">))</span> <span class="k">else</span> <span class="nb">len</span><span class="p">(</span><span class="n">scale_factor</span><span class="p">)</span>
<span class="k">if</span> <span class="n">scale_len</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">scale_factor</span><span class="p">,</span> <span class="p">(</span><span class="nb">float</span><span class="p">,</span> <span class="nb">int</span><span class="p">)):</span>
<span class="n">updated_scale</span> <span class="o">=</span> <span class="p">[</span><span class="n">scale_factor</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span> <span class="o">-</span> <span class="mi">2</span><span class="p">)]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">updated_scale</span> <span class="o">=</span> <span class="n">scale_factor</span>
<span class="n">updated_shape</span> <span class="o">=</span> <span class="p">[</span>
<span class="nb">int</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">updated_scale</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">2</span><span class="p">]</span> <span class="o">*</span>
<span class="n">input_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span> <span class="k">if</span> <span class="n">i</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="k">else</span> <span class="n">input_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">size_len</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> <span class="nb">len</span><span class="p">(</span><span class="n">size</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">size_len</span> <span class="o">==</span> <span class="n">input_ndim</span> <span class="o">-</span> <span class="mi">2</span>
<span class="k">if</span> <span class="n">size_len</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">updated_size</span> <span class="o">=</span> <span class="p">[</span><span class="n">size</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span> <span class="o">-</span> <span class="mi">2</span><span class="p">)]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">updated_size</span> <span class="o">=</span> <span class="n">size</span>
<span class="n">updated_shape</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">input_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">if</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="mi">2</span> <span class="k">else</span> <span class="n">updated_size</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">2</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">layer</span><span class="o">.</span><span class="n">shape</span> <span class="o">=</span> <span class="n">updated_shape</span>
<span class="k">if</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;nearest&#39;</span><span class="p">,</span> <span class="s1">&#39;nearest-exact&#39;</span><span class="p">]</span> <span class="ow">or</span> <span class="n">mode</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">InterpolationMode</span><span class="o">.</span><span class="n">NEAREST</span>
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">ASYMMETRIC</span>
<span class="k">elif</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;linear&#39;</span><span class="p">,</span> <span class="s1">&#39;bilinear&#39;</span><span class="p">,</span> <span class="s1">&#39;trilinear&#39;</span><span class="p">]:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">InterpolationMode</span><span class="o">.</span><span class="n">LINEAR</span>
<span class="k">if</span> <span class="n">align_corners</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">ALIGN_CORNERS</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">HALF_PIXEL</span>
<span class="c1"># TODO, need to confirm the align_corners effect on bilinear mode.</span>
<span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;bilinear&#39;</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">HALF_PIXEL</span>
<span class="k">elif</span> <span class="n">mode</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;bicubic&#39;</span><span class="p">]:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">InterpolationMode</span><span class="o">.</span><span class="n">CUBIC</span>
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">HALF_PIXEL</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">resize_mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">InterpolationMode</span><span class="o">.</span><span class="n">NEAREST</span>
<span class="n">layer</span><span class="o">.</span><span class="n">coordinate_transformation</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ResizeCoordinateTransformation</span><span class="o">.</span><span class="n">ASYMMETRIC</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="matmul">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.matmul">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">matmul</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">mat2</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">transa</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">transb</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">use_fp32_acc</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a matrix multiplication.</span>
<span class="sd"> That operation maps to a tensorrt.IMatrixMultiplyLayer layer. As explained</span>
<span class="sd"> in the TensorRT documentation, it computes the inner product between the</span>
<span class="sd"> two inputs after applying an optional transposition on the inputs.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The first tensor (often called A).</span>
<span class="sd"> mat2 : Tensor</span>
<span class="sd"> The second tensor (often called B).</span>
<span class="sd"> transa : bool</span>
<span class="sd"> Is the first input transposed? Set to &#39;True&#39; if you want the first</span>
<span class="sd"> input to be transposed, &#39;False&#39; otherwise.</span>
<span class="sd"> transb : bool</span>
<span class="sd"> Is the second input transposed? Set to &#39;True&#39; if you want the</span>
<span class="sd"> second input to be transposed, &#39;False&#39; otherwise.</span>
<span class="sd"> use_fp32_acc: bool</span>
<span class="sd"> Set to &#39;True&#39; if for accuracy reason, this fp16 matmul needs to use</span>
<span class="sd"> fp32 accumulation. This can be a per model and per matmul decision.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the inserted layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="c1"># This option is only supported for fp16, but not bf16 or any other precisions.</span>
<span class="n">use_fp32_acc</span> <span class="o">=</span> <span class="n">use_fp32_acc</span> <span class="ow">and</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">HALF</span> <span class="ow">and</span> <span class="n">mat2</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">HALF</span>
<span class="k">if</span> <span class="n">use_fp32_acc</span><span class="p">:</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="s1">&#39;float32&#39;</span><span class="p">)</span>
<span class="n">mat2</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">mat2</span><span class="p">,</span> <span class="s1">&#39;float32&#39;</span><span class="p">)</span>
<span class="nb">input</span><span class="p">,</span> <span class="n">mat2</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">mat2</span><span class="p">)</span>
<span class="n">op0</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">TRANSPOSE</span> <span class="k">if</span> <span class="n">transa</span> \
<span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">NONE</span>
<span class="n">op1</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">TRANSPOSE</span> <span class="k">if</span> <span class="n">transb</span> \
<span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">MatrixOperation</span><span class="o">.</span><span class="n">NONE</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_matrix_multiply</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">op0</span><span class="p">,</span>
<span class="n">mat2</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">op1</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_fp32_acc</span><span class="p">:</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="s2">&quot;float16&quot;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="gemm_swiglu">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gemm_swiglu">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">gemm_swiglu</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">scale_d0</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">scale_d1</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">scale_output</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a matrix multiplication, followed by SwiGLU (`x * SiLU(gate)`) operation.</span>
<span class="sd"> The second SwiGLU operation takes the preceding tensor, splits it into two halves</span>
<span class="sd"> along the last dimension, applies SiLU to the second half and multiply the results. The</span>
<span class="sd"> behaviour is undefined if the last dimension is not even.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The first tensor (often called A).</span>
<span class="sd"> weight : Tensor</span>
<span class="sd"> The second tensor (often called B).</span>
<span class="sd"> bias : Optional[Tensor]</span>
<span class="sd"> The per-channel bias. The plugin with fp8 dtype does not support bias yet.</span>
<span class="sd"> scale_d0 : float</span>
<span class="sd"> The scale for dequantizing x, used for fp8</span>
<span class="sd"> scale_d1 : float</span>
<span class="sd"> The scale for dequantizing gate, used for fp8</span>
<span class="sd"> scale_output : float</span>
<span class="sd"> The scale for quantizing output, used for fp8</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the inserted layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;GemmSwiglu&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_swiglu_plugin</span>
<span class="k">if</span> <span class="n">p_dtype</span> <span class="o">==</span> <span class="s2">&quot;fp8&quot;</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;fp8 gemm_swiglu does not support bias yet&quot;</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pf_has_bias</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;has_bias&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="mi">0</span> <span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="mi">1</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pf_scale_d0</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;scale_d0&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">scale_d0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">pf_scale_d1</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;scale_d1&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">scale_d1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">pf_scale_output</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;scale_output&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">scale_output</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">(</span>
<span class="p">[</span><span class="n">pf_type</span><span class="p">,</span> <span class="n">pf_has_bias</span><span class="p">,</span> <span class="n">pf_scale_d0</span><span class="p">,</span> <span class="n">pf_scale_d1</span><span class="p">,</span> <span class="n">pf_scale_output</span><span class="p">])</span>
<span class="n">gemm_swiglu_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;gemm_swiglu&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="c1"># TODO(anchengc) pass nullptr when no bias</span>
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">bias</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span>
<span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">weight</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">dtype</span><span class="p">)))</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">gemm_swiglu_plug</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="constant">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.constant">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">constant</span><span class="p">(</span><span class="n">ndarray</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
<span class="n">as_dtype</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span> <span class="o">|</span> <span class="kc">None</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">as_shape</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a constant layer.</span>
<span class="sd"> TensorRT graphs encapsulate constant values in the form of constant layers</span>
<span class="sd"> (tensorrt.IConstantLayer). That function creates such a layer from a Numpy</span>
<span class="sd"> array of values. After compilation of the network by TensorRT, those</span>
<span class="sd"> weights are stored in the serialized TensorRT engine.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> ndarray : numpy.ndarray</span>
<span class="sd"> The array of values (weights) encapsulated by this constant layer.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the inserted layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">trt_dtype</span> <span class="o">=</span> <span class="n">np_dtype_to_trt</span><span class="p">(</span><span class="n">ndarray</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> <span class="k">if</span> <span class="n">as_dtype</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">as_dtype</span>
<span class="n">trt_shape</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Dims</span><span class="p">(</span>
<span class="n">ndarray</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="k">if</span> <span class="n">as_shape</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Dims</span><span class="p">(</span><span class="n">as_shape</span><span class="p">)</span>
<span class="n">trt_count</span> <span class="o">=</span> <span class="mi">1</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">trt_shape</span><span class="p">)):</span>
<span class="n">trt_count</span> <span class="o">*=</span> <span class="n">trt_shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">weights</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">(</span><span class="n">trt_dtype</span><span class="p">,</span> <span class="n">ndarray</span><span class="o">.</span><span class="n">ctypes</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">trt_count</span><span class="p">)</span>
<span class="c1"># Prevent underlying numpy array from going out of scope</span>
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">register_ndarray</span><span class="p">(</span><span class="n">ndarray</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_constant</span><span class="p">(</span><span class="n">trt_shape</span><span class="p">,</span> <span class="n">weights</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_output_type</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">trt_dtype</span><span class="p">)</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="c1"># TODO: remove this WAR after https://nvbugs/4359151 fixed.</span>
<span class="n">set_np_weight</span><span class="p">(</span><span class="n">default_trtnet</span><span class="p">(),</span> <span class="n">layer</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">ndarray</span><span class="p">)</span>
<span class="k">return</span> <span class="n">tensor</span></div>
<span class="c1"># TODO: TensorRT uses sizes of the output dimensions.</span>
<span class="c1"># DL framework uses ends usually. Will change it to ends.</span>
<div class="viewcode-block" id="slice">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.slice">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">slice</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">starts</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">sizes</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">strides</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">mode</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">SampleMode</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">fill_value</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to extract a slice from a tensor.</span>
<span class="sd"> As described in the TensorRT documentation of the ISliceLayer, the slice</span>
<span class="sd"> layer has two variants: Static and dynamic.</span>
<span class="sd"> For static slicing, this function takes the starts and sizes values in the</span>
<span class="sd"> different dimensions to slice at layer creation time via a sequence of</span>
<span class="sd"> integers. For dynamic slicing, it accepts starts and sizes as</span>
<span class="sd"> tensorrt.ITensor`s.</span>
<span class="sd"> The slice layer selects for each dimension a start location from within the</span>
<span class="sd"> input tensor, and copies elements to the output tensor using a stride of 1</span>
<span class="sd"> across the input tensor. Start and size tensors must be 1-D int32 shape</span>
<span class="sd"> tensors if not specified as a sequence of integers.</span>
<span class="sd"> As an example, on input = [[0, 2, 4], [1, 3, 5]], the call to</span>
<span class="sd"> slice(input, start=[1, 0], size=[1, 2])</span>
<span class="sd"> will produce the tensor [[1, 3]] as output. The slice operator when</span>
<span class="sd"> executed by TensorRT will copy one row (because size[0] == 1) starting from</span>
<span class="sd"> the 2nd row (because start[0] == 1) and two columns (size[1] == 2) starting</span>
<span class="sd"> from the 1st column (because start[1] == 0).</span>
<span class="sd"> In pseudo-code the behavior of that operation can be described as follows</span>
<span class="sd"> for a 2D tensor (and easily be extended to more dimensions):</span>
<span class="sd"> output = Tensor(shape=sizes)</span>
<span class="sd"> for ii in range(sizes[0]):</span>
<span class="sd"> for jj in range(sizes[1]):</span>
<span class="sd"> output[ii][jj] = input[starts[0]+ii][starts[1]+jj]</span>
<span class="sd"> Note that it is common in deep-learning frameworks to use ranges</span>
<span class="sd"> [start:end] for similar operations. It can be emulated by setting the sizes</span>
<span class="sd"> argument such that in each dimension [start:start+size] == [start:end] i.e.</span>
<span class="sd"> size = end-start.</span>
<span class="sd"> TensorRT supports different slice modes but that function restricts that</span>
<span class="sd"> choice to `mode == tensorrt.SampleMode.STRICT_BOUNDS`.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the slicing is performed.</span>
<span class="sd"> starts : Union[Tensor, Sequence[int]]</span>
<span class="sd"> The starting points, in the input tensor, and each dimension.</span>
<span class="sd"> sizes : Union[Tensor, Sequence[int]]</span>
<span class="sd"> The number of elements in each dimension of the sliced tensor (output).</span>
<span class="sd"> strides : Union[Tensor, Sequence[int]]</span>
<span class="sd"> The step be taken from start, in input tensor.</span>
<span class="sd"> mode : trt.SampleMode</span>
<span class="sd"> The mode that controls how the slice operation handles out of bounds coordinates.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the slice layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">input_ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="n">trt_starts</span> <span class="o">=</span> <span class="n">starts</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">starts</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="n">trt_starts</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)]</span> <span class="c1"># unused dummy value</span>
<span class="n">trt_sizes</span> <span class="o">=</span> <span class="n">sizes</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="n">trt_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)]</span> <span class="c1"># unused dummy value</span>
<span class="n">trt_strides</span> <span class="o">=</span> <span class="n">strides</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">strides</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="ow">or</span> <span class="n">strides</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">trt_strides</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">input_ndim</span><span class="p">)]</span>
<span class="k">if</span> <span class="n">fill_value</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">fill_value</span><span class="p">,</span> <span class="nb">float</span><span class="p">):</span>
<span class="n">fill_value</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">(</span><span class="n">fill_value</span><span class="p">))</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_slice</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">start</span><span class="o">=</span><span class="n">trt_starts</span><span class="p">,</span>
<span class="n">shape</span><span class="o">=</span><span class="n">trt_sizes</span><span class="p">,</span>
<span class="n">stride</span><span class="o">=</span><span class="n">trt_strides</span><span class="p">)</span>
<span class="k">if</span> <span class="n">mode</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">starts</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">starts</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">sizes</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">sizes</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">strides</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">strides</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">if</span> <span class="n">mode</span> <span class="ow">is</span> <span class="n">trt</span><span class="o">.</span><span class="n">SampleMode</span><span class="o">.</span><span class="n">FILL</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">fill_value</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="n">fill_value</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="pad">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.pad">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">pad</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">pad</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">Tensor</span><span class="p">],</span>
<span class="n">mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;constant&#39;</span><span class="p">,</span>
<span class="n">value</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a pad layer.</span>
<span class="sd"> The padding layer adds zero-padding at the start and end of the input tensor. And the</span>
<span class="sd"> padding size by which to pad some dimensions of input are described starting from the</span>
<span class="sd"> last dimension and moving forward.</span>
<span class="sd"> `[len(pad) / 2]` dimensions of input will be padded. For example, to pad only the last</span>
<span class="sd"> dimension of the input tensor, then pad has the form [padding_left, padding_right]; to</span>
<span class="sd"> pad the last 2 dimensions of the input tensor, then use [padding_left, padding_right,</span>
<span class="sd"> padding_top, padding_bottom]; to pad the last 3 dimensions, use [padding_left,</span>
<span class="sd"> padding_right, padding_top, padding_bottom, padding_front, padding_back].</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which the padding_2d is performed.</span>
<span class="sd"> pad : sequence of int</span>
<span class="sd"> An m-elements tuple for padding, where its length m meets the requirement that</span>
<span class="sd"> m &lt;= 2*input dimensions, and m is even.</span>
<span class="sd"> mode : str</span>
<span class="sd"> Only \&#39;constant\&#39; is supported.</span>
<span class="sd"> value : float</span>
<span class="sd"> Fill value for &#39;constant&#39; padding. Default: 0.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the inserted layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;constant&quot;</span><span class="p">,</span> <span class="s2">&quot;Only `&#39;constant&#39;` is supported now.&quot;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pad</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pad</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
<span class="k">assert</span> <span class="p">(</span>
<span class="nb">len</span><span class="p">(</span><span class="n">pad</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">pad</span><span class="p">)</span> <span class="o">&lt;=</span> <span class="mi">2</span> <span class="o">*</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="p">),</span> <span class="s2">&quot;The length of `pad` should be even and less than 2*input.ndim&quot;</span>
<span class="n">pad</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">pad</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">))</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pad</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="n">pad</span> <span class="o">=</span> <span class="n">pad</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
<span class="k">assert</span> <span class="p">(</span>
<span class="n">pad</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">pad</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">&lt;=</span> <span class="mi">2</span> <span class="o">*</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="p">),</span> <span class="s2">&quot;The length of `pad` should be even and less than 2*input.ndim&quot;</span>
<span class="n">pad</span> <span class="o">=</span> <span class="n">pad</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="s2">&quot;int32&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;pad type </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="n">pad</span><span class="p">)</span><span class="si">}</span><span class="s2"> not supported&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">value</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">value</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">pad</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">constant</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)),</span>
<span class="n">pad</span><span class="p">])</span> <span class="c1"># pre-padding the indices</span>
<span class="n">padding_index</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="n">padding_index</span><span class="p">[</span><span class="o">-</span><span class="p">(</span><span class="n">pad</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):]</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">pad</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span>
<span class="o">-</span><span class="mi">1</span><span class="p">))</span> <span class="c1"># reverse the indices</span>
<span class="n">pad</span> <span class="o">=</span> <span class="n">index_select</span><span class="p">(</span><span class="n">pad</span><span class="p">,</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">index</span><span class="o">=</span><span class="n">constant</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">padding_index</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)))</span>
<span class="n">pre_padding</span><span class="p">,</span> <span class="n">post_padding</span> <span class="o">=</span> <span class="n">chunk</span><span class="p">(</span><span class="n">pad</span><span class="p">,</span> <span class="n">chunks</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">start</span> <span class="o">=</span> <span class="p">(</span><span class="n">pre_padding</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span> <span class="o">*</span> <span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="s1">&#39;int32&#39;</span><span class="p">)</span>
<span class="n">extend_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">pre_padding</span> <span class="o">+</span> <span class="n">post_padding</span><span class="p">)</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
<span class="n">size</span> <span class="o">=</span> <span class="p">(</span><span class="n">extend_size</span> <span class="o">+</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">))</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="s1">&#39;int32&#39;</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_slice</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">start</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">(),</span>
<span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">(),</span>
<span class="n">stride</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
<span class="n">layer</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">SampleMode</span><span class="o">.</span><span class="n">FILL</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="n">SliceInputType</span><span class="o">.</span><span class="n">start</span><span class="p">,</span> <span class="n">start</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="n">SliceInputType</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">size</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="n">SliceInputType</span><span class="o">.</span><span class="n">fill_value</span><span class="p">,</span>
<span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">input</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="rand">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.rand">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">rand</span><span class="p">(</span><span class="n">shape</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">low</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">high</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">dtype</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]</span> <span class="o">=</span> <span class="s1">&#39;float32&#39;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> This operation adds a fill layer that generates a random (uniform) tensor with the specified shape and data type.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> shape: Tensor</span>
<span class="sd"> The shape of the tensor needed to be generated.</span>
<span class="sd"> low: float</span>
<span class="sd"> The minimum value (inclusive) of the range used for random.</span>
<span class="sd"> high: float</span>
<span class="sd"> The maximum value (inclusive) of the range used for random.</span>
<span class="sd"> dtype: Union[str, trt.DataType]</span>
<span class="sd"> The desired data type for the output tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The generated random tensor produced by the fill layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="c1"># NOTE: DISABLED FOR NOW UNTIL THE FILL LAYER (RANDOM_UNIFORM) in TRT IS FIXED</span>
<span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s2">&quot;The rand() op is temporarily disabled.&quot;</span>
<span class="n">low</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">(</span><span class="n">low</span><span class="p">))</span>
<span class="n">high</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">(</span><span class="n">high</span><span class="p">))</span>
<span class="n">trt_dtype</span> <span class="o">=</span> <span class="n">dtype</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">)</span> <span class="k">else</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_fill</span><span class="p">([</span><span class="mi">0</span><span class="p">],</span> <span class="n">trt</span><span class="o">.</span><span class="n">FillOperation</span><span class="o">.</span><span class="n">RANDOM_UNIFORM</span><span class="p">,</span>
<span class="n">trt_dtype</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">shape</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">low</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">high</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="categorical_sample">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.categorical_sample">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">categorical_sample</span><span class="p">(</span><span class="n">probs</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">rand_data</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> This is a sampling operation and an equivalent of torch.distributions.Categorical.sample()</span>
<span class="sd"> i.e. given a probability distribution tensor, it samples an index of that tensor.</span>
<span class="sd"> See: https://pytorch.org/docs/stable/distributions.html#torch.distributions.categorical.Categorical.sample</span>
<span class="sd"> NOTE: This assumes that the given probabilities are **not** normalized.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> probs: Tensor</span>
<span class="sd"> A 1-D floating point tensor representing the probability distributions.</span>
<span class="sd"> rand_data: Tensor (optional)</span>
<span class="sd"> A random tensor of same shape as `probs` tensor.</span>
<span class="sd"> If not provided, this function will add a rand() op to generate it and use for sampling.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A tensor containing a single index of the `probs` tensor representing the sample.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">probs</span> <span class="o">=</span> <span class="n">probs</span> <span class="o">/</span> <span class="nb">sum</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">rand_shape</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">assert</span> <span class="n">probs</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">probs</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
<span class="n">rand_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
<span class="n">rand_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">rand_shape</span><span class="p">)</span>
<span class="k">if</span> <span class="n">rand_data</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">rand_data</span> <span class="o">=</span> <span class="n">rand</span><span class="p">(</span><span class="n">rand_shape</span><span class="p">,</span> <span class="n">low</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">probs</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">rand_shape</span> <span class="o">==</span> <span class="n">shape</span><span class="p">(</span><span class="n">rand_data</span><span class="p">)</span>
<span class="n">rand_data</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">rand_data</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">shape</span><span class="p">(</span><span class="n">probs</span><span class="p">))</span>
<span class="n">cum_probs</span> <span class="o">=</span> <span class="n">cumsum</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">cmp</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">cum_probs</span> <span class="o">&gt;=</span> <span class="n">rand_data</span><span class="p">,</span> <span class="n">probs</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">samples</span> <span class="o">=</span> <span class="n">argmax</span><span class="p">(</span><span class="n">cmp</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">samples</span></div>
<div class="viewcode-block" id="Conditional">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Conditional">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">Conditional</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to conditionally execute two code paths/subgraphs.</span>
<span class="sd"> Usage:</span>
<span class="sd"> 1. conditional = Conditional(condition)</span>
<span class="sd"> 2. input_1_ = conditional.add_input(input_1)</span>
<span class="sd"> ...</span>
<span class="sd"> input_n_ = conditional.add_input(input_n)</span>
<span class="sd"> 3. Construct the graph to get true_output_value and false_output_value using input_1_, ..., input_n_</span>
<span class="sd"> 4. output = conditional.add_output(true_output_value, false_output_value)</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">condition</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_if_conditional</span><span class="p">()</span>
<span class="k">if</span> <span class="n">condition</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">condition</span> <span class="o">=</span> <span class="n">view</span><span class="p">(</span><span class="n">condition</span><span class="p">,</span> <span class="p">[])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">layer</span><span class="o">.</span><span class="n">set_condition</span><span class="p">(</span><span class="n">condition</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<div class="viewcode-block" id="Conditional.add_input">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Conditional.add_input">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">add_input</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="n">in_node</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer</span><span class="o">.</span><span class="n">add_input</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">in_node</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">in_node</span><span class="p">)</span></div>
<div class="viewcode-block" id="Conditional.add_output">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.Conditional.add_output">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">add_output</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">true_value</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">false_value</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="n">out_node</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer</span><span class="o">.</span><span class="n">add_output</span><span class="p">(</span><span class="n">true_value</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">false_value</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">out_node</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">out_node</span><span class="p">)</span></div>
</div>
<span class="c1"># TODO: support step.</span>
<div class="viewcode-block" id="arange">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.arange">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">arange</span><span class="p">(</span><span class="n">start</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span> <span class="n">end</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span>
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to fill a 1D tensor.</span>
<span class="sd"> The tensor is filled with the values between start and end with a step of 1</span>
<span class="sd"> between the different elements. In pseudo-code, it corresponds to a tensor</span>
<span class="sd"> populated with the values:</span>
<span class="sd"> output = Tensor([dtype(ii) for ii in range(start, end, 1)])</span>
<span class="sd"> For example, a call to arange(3, 6, &#39;int32&#39;) will add an operation to the</span>
<span class="sd"> TensorRT graph that will produce [3, 4, 5] when executed. The call to</span>
<span class="sd"> arange(2, 5, &#39;float32&#39;) will add a layer to generate [2.0, 3.0, 4.0].</span>
<span class="sd"> This operation is implemented using a tensorrt.IFillLayer in</span>
<span class="sd"> trt.FillOperation.LINSPACE mode.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> start : Union[Tensor, int]</span>
<span class="sd"> The starting point of the range.</span>
<span class="sd"> end : Union[Tensor, int]</span>
<span class="sd"> The end point of the range.</span>
<span class="sd"> dtype : str</span>
<span class="sd"> The type of the elements. See _str_to_trt_dtype_dict in _utils.py</span>
<span class="sd"> for a list of supported types and type names.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the fill layer. It is a 1D tensor containing</span>
<span class="sd"> `end-start` elements of type `dtype`.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">res_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="c1"># Python-int endpoints become constant tensors; the constant width tracks the target dtype.</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">end</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span>
<span class="n">array_func</span> <span class="o">=</span> <span class="n">int32_array</span> <span class="k">if</span> <span class="n">res_dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span> <span class="k">else</span> <span class="n">int64_array</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">array_func</span><span class="p">(</span><span class="n">start</span><span class="p">))</span>
<span class="n">end</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">array_func</span><span class="p">(</span><span class="n">end</span><span class="p">))</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">end</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">start</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span> <span class="ow">or</span> <span class="n">start</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span>
<span class="k">assert</span> <span class="n">end</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span> <span class="ow">or</span> <span class="n">end</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span>
<span class="c1"># Reconcile mixed int32/int64 endpoints: prefer the narrower type only when the</span>
<span class="c1"># requested result dtype is int32, otherwise widen to int64 to avoid truncation.</span>
<span class="k">if</span> <span class="n">start</span><span class="o">.</span><span class="n">dtype</span> <span class="o">!=</span> <span class="n">end</span><span class="o">.</span><span class="n">dtype</span><span class="p">:</span>
<span class="k">if</span> <span class="n">start</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span> <span class="c1"># end == trt.int64</span>
<span class="k">if</span> <span class="n">res_dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span>
<span class="n">end</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">end</span><span class="p">,</span> <span class="s2">&quot;int32&quot;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="s2">&quot;int64&quot;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span> <span class="c1"># start == trt.int64 and end == trt.int32</span>
<span class="k">if</span> <span class="n">res_dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="s2">&quot;int32&quot;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">end</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">end</span><span class="p">,</span> <span class="s2">&quot;int64&quot;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">%s</span><span class="s2"> is not supported&quot;</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">start</span><span class="p">))</span>
<span class="k">assert</span> <span class="n">start</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">end</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;start type (</span><span class="si">{</span><span class="n">start</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">) != end type (</span><span class="si">{</span><span class="n">end</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">)&quot;</span>
<span class="n">step</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">start</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">to_array</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">num</span> <span class="o">=</span> <span class="n">end</span> <span class="o">-</span> <span class="n">start</span>
<span class="n">num</span> <span class="o">=</span> <span class="n">num</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_fill</span><span class="p">([</span><span class="mi">0</span><span class="p">],</span> <span class="n">trt</span><span class="o">.</span><span class="n">FillOperation</span><span class="o">.</span><span class="n">LINSPACE</span><span class="p">,</span>
<span class="n">start</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">num</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span> <span class="c1"># rank = 1</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">start</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span> <span class="c1"># rank = 0</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">step</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span> <span class="c1"># rank = 1</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="c1"># The fill was emitted in the integer dtype of start/end; cast once at the end</span>
<span class="c1"># if the caller asked for a different element type (e.g. a float dtype).</span>
<span class="k">if</span> <span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span> <span class="o">!=</span> <span class="n">res_dtype</span><span class="p">:</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">return</span> <span class="n">tensor</span></div>
<div class="viewcode-block" id="expand">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">expand</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">expand_shape</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to expand a tensor.</span>
<span class="sd"> The operation expands the input tensor in the singleton dimensions to the</span>
<span class="sd"> size indicated by the corresponding dimension in the `expand_shape` tensor.</span>
<span class="sd"> In other words, given an input tensor with dimensions of size 1, those</span>
<span class="sd"> dimensions will be expanded to the size in `expand_shape`.</span>
<span class="sd"> For example, a tensor of shape [4, 3, 1, 3] will be expanded to a tensor of</span>
<span class="sd"> shape [4, 3, 2, 3] by the layer created using expand(input, [4, 3, 2, 3]).</span>
<span class="sd"> The expansion may either replicate the values or be mapped to a view with a</span>
<span class="sd"> stride of 0 in the expanded dimensions. For example, for a tensor [[3, 2]] of</span>
<span class="sd"> shape [1, 2],</span>
<span class="sd"> expand([[3, 2]], [2, 2])</span>
<span class="sd"> can be used to expand the input to [[3, 2], [3, 2]].</span>
<span class="sd"> This operation is implemented using a tensorrt.ISliceLayer. The current</span>
<span class="sd"> implementation does not verify that non singleton dimensions are not</span>
<span class="sd"> shrunk. In other words, for an input of shape [4, 1, 2],</span>
<span class="sd"> expand(input, [3, 2, 2])</span>
<span class="sd"> will produce a tensor of shape [3, 2, 2]. That behavior is subject to</span>
<span class="sd"> change in the future.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> expand_shape : Tensor</span>
<span class="sd"> The new shape of the expanded tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the expand layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_slice</span><span class="p">(</span>
<span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">start</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)],</span>
<span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)],</span> <span class="c1"># unused dummy value</span>
<span class="n">stride</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span> <span class="c1"># unused dummy value</span>
<span class="p">)</span>
<span class="c1"># The static start stays all-zero; the real output shape (slice input 2) and</span>
<span class="c1"># per-dimension stride (slice input 3) are provided dynamically below.</span>
<span class="c1"># The stride is either:</span>
<span class="c1"># 0 for dimensions of size 1 (i.e. shape(input, i) - 1 == 1 - 1 == 0) or,</span>
<span class="c1"># 1 for dimensions of size &gt; 1 since minimum(value &gt;= 1, 1) == 1.</span>
<span class="n">stride_tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span>
<span class="p">[</span><span class="n">minimum</span><span class="p">((</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)])</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">expand_shape</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride_tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="einsum">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.einsum">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">einsum</span><span class="p">(</span><span class="n">einsum_eq</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">inputs</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tensor</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an Einsum operation.</span>
<span class="sd"> That operation maps to tensorrt.IEinsumLayer. As explained in the TensorRT</span>
<span class="sd"> documentation, this layer implements a summation over the elements of the</span>
<span class="sd"> inputs along dimensions specified by the equation parameter, based on the</span>
<span class="sd"> Einstein summation convention. The layer can have one or more inputs of</span>
<span class="sd"> rank &gt;= 0. All the inputs must be of same data type. This layer supports</span>
<span class="sd"> all TensorRT data types except bool. There is one output tensor of the same</span>
<span class="sd"> type as the input tensors. The shape of output tensor is determined by the</span>
<span class="sd"> equation.</span>
<span class="sd"> The equation specifies ASCII lower-case letters for each dimension in the</span>
<span class="sd"> inputs in the same order as the dimensions, separated by comma for each</span>
<span class="sd"> input. The dimensions labeled with the same subscript must match or be</span>
<span class="sd"> able to be broadcasted. Repeated subscript labels in one input take the diagonal.</span>
<span class="sd"> Repeating a label across multiple inputs means that those axes will be</span>
<span class="sd"> multiplied. Omitting a label from the output means values along those axes</span>
<span class="sd"> will be summed. In implicit mode, the indices which appear once in the</span>
<span class="sd"> expression will be part of the output in increasing alphabetical order. In</span>
<span class="sd"> explicit mode, the output can be controlled by specifying output subscript</span>
<span class="sd"> labels by adding an arrow (-&gt;) followed by subscripts for the output. For</span>
<span class="sd"> example, “ij,jk-&gt;ik” is equivalent to “ij,jk”. Ellipsis (‘…’) can be used</span>
<span class="sd"> in place of subscripts to broadcast the dimensions. See the TensorRT</span>
<span class="sd"> Developer Guide for more details on equation syntax.</span>
<span class="sd"> Many common operations can be expressed using the Einsum equation. For</span>
<span class="sd"> example:</span>
<span class="sd"> Matrix Transpose: ij-&gt;ji</span>
<span class="sd"> Sum: ij-&gt;</span>
<span class="sd"> Matrix-Matrix Multiplication: ik,kj-&gt;ij</span>
<span class="sd"> Dot Product: i,i-&gt;</span>
<span class="sd"> Matrix-Vector Multiplication: ik,k-&gt;i</span>
<span class="sd"> Batch Matrix Multiplication: ijk,ikl-&gt;ijl</span>
<span class="sd"> Batch Diagonal: …ii-&gt;…i</span>
<span class="sd"> Note that TensorRT does not support ellipsis or diagonal operations so,</span>
<span class="sd"> neither, does TensorRT-LLM.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> einsum_eq : str</span>
<span class="sd"> The Einsum equation.</span>
<span class="sd"> inputs: Sequence[Tensor]</span>
<span class="sd"> The sequence of inputs consumed by the Einsum operation.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the Einsum operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="c1"># Thin wrapper: unwrap the trt_tensor handles and let TensorRT parse the equation.</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_einsum</span><span class="p">([</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">],</span>
<span class="n">einsum_eq</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="permute">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.permute">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">permute</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dims</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to permute the dimensions of a tensor.</span>
<span class="sd"> The dimensions of the input tensor are permuted according to the sequence</span>
<span class="sd"> of dimensions in &#39;dims&#39;. That operation maps to tensorrt.IShuffleLayer where</span>
<span class="sd"> the second transposition is described by the indices in &#39;dims&#39;.</span>
<span class="sd"> Given a tensor of rank N, the result of the permutation is a tensor of rank</span>
<span class="sd"> N in which the i-th input dimension maps to the dims[i]-th dimension.</span>
<span class="sd"> For example, permute(input, [1, 0]) will transpose a 2D tensor by permuting</span>
<span class="sd"> the rows and columns.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to permute.</span>
<span class="sd"> dims : Sequence[int]</span>
<span class="sd"> The description of the permutation.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the permutation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="c1"># Negative axes are resolved against the input rank before configuring the shuffle.</span>
<span class="n">dims</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="n">dims</span><span class="p">),</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shuffle</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="c1"># Only the shuffle&#39;s second transpose is used; reshape and first transpose stay identity.</span>
<span class="n">layer</span><span class="o">.</span><span class="n">second_transpose</span> <span class="o">=</span> <span class="n">dims</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="transpose">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.transpose">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">transpose</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim0</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim1</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to transpose two dimensions of a tensor.</span>
<span class="sd"> That operation produces a tensor in which the dimensions &#39;dim0&#39; and &#39;dim1&#39;</span>
<span class="sd"> are permuted. The other dimensions, if the rank of the tensor is greater</span>
<span class="sd"> than 2, remain untouched.</span>
<span class="sd"> That function is a helper built on the &#39;functional.permute&#39; function.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to transpose.</span>
<span class="sd"> dim0 : int</span>
<span class="sd"> The first dimension to transpose.</span>
<span class="sd"> dim1 : int</span>
<span class="sd"> The second dimension to transpose.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the permutation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="c1"># Build the identity permutation, swap the two requested axes, and delegate to permute.</span>
<span class="n">permutation</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()))</span>
<span class="n">permutation</span><span class="p">[</span><span class="n">dim0</span><span class="p">]</span> <span class="o">=</span> <span class="n">dim1</span>
<span class="n">permutation</span><span class="p">[</span><span class="n">dim1</span><span class="p">]</span> <span class="o">=</span> <span class="n">dim0</span>
<span class="k">return</span> <span class="n">permute</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">permutation</span><span class="p">)</span></div>
<div class="viewcode-block" id="view">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.view">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">view</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">zero_is_placeholder</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to create a view of a tensor.</span>
<span class="sd"> That operation adds a tensorrt.IShuffleLayer to the network. If the &#39;shape&#39;</span>
<span class="sd"> parameter is a Tensor, that view is dynamic. Otherwise, it is a static</span>
<span class="sd"> view.</span>
<span class="sd"> Note that TensorRT limits the number of inferred dimensions to 1. It means</span>
<span class="sd"> that the shape sequence or tensor cannot contain more than one -1. This</span>
<span class="sd"> function enforces that constraint and will assert if it is not respected.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to transpose.</span>
<span class="sd"> shape : Union[Tensor, Sequence[int]]</span>
<span class="sd"> The shape of the new tensor.</span>
<span class="sd"> zero_is_placeholder : bool</span>
<span class="sd"> When that parameter is True, the 0s in &#39;shape&#39; are replaced by the</span>
<span class="sd"> sizes of the corresponding dimensions from the &#39;input&#39;. Otherwise,</span>
<span class="sd"> the dimensions corresponding to 0s are shrunk.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the view/shuffle layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="c1"># TensorRT demands that at most one dimension is permitted to be specified as -1</span>
<span class="k">def</span><span class="w"> </span><span class="nf">assert_no_more_than_one_inferred_dim</span><span class="p">(</span><span class="nb">list</span><span class="p">):</span>
<span class="n">inferred_dim_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">list</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">inferred_dim_list</span><span class="p">)</span> <span class="o">&lt;=</span> <span class="mi">1</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shuffle</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">zero_is_placeholder</span> <span class="o">=</span> <span class="n">zero_is_placeholder</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="c1"># NOTE(review): for a Tensor shape this iterates shape.shape (the dimensions of the</span>
<span class="c1"># 1-D shape tensor), not its runtime values, so the -1 check is effectively a no-op</span>
<span class="c1"># for dynamic shapes — confirm intent upstream.</span>
<span class="n">assert_no_more_than_one_inferred_dim</span><span class="p">(</span><span class="n">shape</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">shape</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
<span class="n">assert_no_more_than_one_inferred_dim</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">reshape_dims</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">%s</span><span class="s2"> is not supported&quot;</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">shape</span><span class="p">))</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="flatten">
<!-- Rendered source of tensorrt_llm.functional.flatten: collapses dims [start_dim, end_dim]
     of `input` into a single dimension (product of sizes) and returns view(input, new_shape).
     Generated Sphinx viewcode page; fix the Python source, not this HTML. -->
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.flatten">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">flatten</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">start_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">end_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Flattens input by reshaping it into a one-dimensional tensor.</span>
<span class="sd"> If start_dim or end_dim are passed, only dimensions starting with start_dim and</span>
<span class="sd"> ending with end_dim are flattened. The order of elements in input is unchanged.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to flatten.</span>
<span class="sd"> start_dim : int</span>
<span class="sd"> The first dim to flatten.</span>
<span class="sd"> end_dim : int</span>
<span class="sd"> The last dim to flatten.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the flatten layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">shape</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">shape</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">if</span> <span class="n">start_dim</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span> <span class="n">start_dim</span> <span class="o">+=</span> <span class="n">ndim</span>
<span class="k">if</span> <span class="n">end_dim</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span> <span class="n">end_dim</span> <span class="o">+=</span> <span class="n">ndim</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">start_dim</span><span class="p">):</span>
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
<span class="k">if</span> <span class="n">end_dim</span> <span class="o">-</span> <span class="n">start_dim</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">flat_dim</span> <span class="o">=</span> <span class="mi">1</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">start_dim</span><span class="p">,</span> <span class="n">end_dim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
<span class="n">flat_dim</span> <span class="o">*=</span> <span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">flat_dim</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">end_dim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">ndim</span><span class="p">):</span>
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
<span class="k">return</span> <span class="n">view</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">new_shape</span><span class="p">)</span></div>
<!-- tensorrt_llm.functional.expand_dims: inserts size-1 axes at the positions listed in
     `dim` by assembling a new shape tensor (constants 1 interleaved with gather() over
     shape(input)) and calling view(). Generated Sphinx viewcode page. -->
<div class="viewcode-block" id="expand_dims">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand_dims">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">shape_cast_dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to expand the tensor shape with singleton dimensions.</span>
<span class="sd"> That function adds a tensorrt.IShuffleLayer to the network. Given an &#39;input&#39;</span>
<span class="sd"> of rank N and a sequence of M dimensions, the output tensor produced by</span>
<span class="sd"> this operation (when executed by TensorRT) will have a rank of N+M. Singleton</span>
<span class="sd"> dimensions will be inserted at the different positions in &#39;dim&#39;.</span>
<span class="sd"> The pseudo-code for that operation is:</span>
<span class="sd"> new_shape, ii = [], 0</span>
<span class="sd"> for jj in range(input.rank() + len(dim)):</span>
<span class="sd"> new_shape.append(1 if jj in dims else input.shape[ii++])</span>
<span class="sd"> For example, for a tensor of shape [3, 4, 1, 5]</span>
<span class="sd"> expand_dims(input, [0, 2])</span>
<span class="sd"> will produce a tensor of shape [1, 3, 1, 4, 1, 5].</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to expand.</span>
<span class="sd"> dim : Union[int, Sequence[int]]</span>
<span class="sd"> The positions in the output tensor where to insert singleton</span>
<span class="sd"> dimensions.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the shuffle layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">dim</span> <span class="o">=</span> <span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="p">)</span>
<span class="n">out_ndim</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span> <span class="o">+</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="n">input_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">cast_to_dtype</span><span class="o">=</span><span class="n">shape_cast_dtype</span><span class="p">)</span>
<span class="n">out_shapes</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">j</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">out_ndim</span><span class="p">):</span>
<span class="k">if</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dim</span><span class="p">:</span>
<span class="n">out_shapes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">out_shapes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">gather</span><span class="p">(</span><span class="n">input_shape</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">j</span><span class="p">))</span>
<span class="n">j</span> <span class="o">=</span> <span class="n">j</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">out_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">out_shapes</span><span class="p">)</span>
<span class="k">return</span> <span class="n">view</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">out_shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></div>
<span class="c1"># NOTE: Jointly added with Apple</span>
<!-- tensorrt_llm.functional.squeeze: drops the size-1 dimensions listed in `dim` (all dims
     when None) by concatenating the surviving per-dim shape scalars and calling view().
     Rendered docstring grammar fixed ("This functions" to "This function"; plural
     "singleton dimensions"); the same fix should land in the Python source. -->
<div class="viewcode-block" id="squeeze">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.squeeze">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">squeeze</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">zero_is_placeholder</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to remove singleton dimensions of a tensor.</span>
<span class="sd"> This function creates an operation that removes singleton dimensions</span>
<span class="sd"> (dimensions of size 1) at positions &#39;dim&#39; in the input tensor. It works with</span>
<span class="sd"> negative values for the &#39;dim&#39;.</span>
<span class="sd"> For example, for a tensor &#39;input&#39; of shape [1, 4, 1, 4]:</span>
<span class="sd"> squeeze(input, 0) will produce an output of shape [4, 1, 4],</span>
<span class="sd"> squeeze(input, 2) will produce an output of shape [1, 4, 4],</span>
<span class="sd"> squeeze(input, [0, 2]) will produce an output of shape [4, 4],</span>
<span class="sd"> squeeze(input, [-2]) will produce an output of shape [1, 4, 4],</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor for which the singleton dimensions will be removed.</span>
<span class="sd"> dim : Union[int, Sequence[int]]</span>
<span class="sd"> The index of the singleton dimensions in the input tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">dim</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()))</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">dim</span> <span class="o">=</span> <span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="p">)</span>
<span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">s</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">shape</span><span class="p">):</span>
<span class="k">if</span> <span class="n">s</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">dim</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">else</span> <span class="p">[]</span>
<span class="nb">input</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="n">zero_is_placeholder</span><span class="p">)</span>
<span class="k">return</span> <span class="nb">input</span></div>
<!-- tensorrt_llm.functional.unsqueeze: inserts one size-1 axis at `axis` (negative axes
     resolved against ndim + 1) by delegating to expand_dims. Rendered docstring grammar
     fixed ("That functions creates an operation that insert" to "That function creates an
     operation that inserts"); the same fix should land in the Python source. -->
<div class="viewcode-block" id="unsqueeze">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.unsqueeze">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">unsqueeze</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">axis</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to insert a singleton dimension to a tensor.</span>
<span class="sd"> That function creates an operation that inserts a singleton dimension</span>
<span class="sd"> (dimension of size 1) at position &#39;axis&#39; in the output tensor. It works with</span>
<span class="sd"> negative values for the &#39;axis&#39;.</span>
<span class="sd"> For example, for a tensor &#39;input&#39; of shape [4, 4]:</span>
<span class="sd"> unsqueeze(input, 0) will produce an output of shape [1, 4, 4],</span>
<span class="sd"> unsqueeze(input, 1) will produce an output of shape [4, 1, 4],</span>
<span class="sd"> unsqueeze(input, -1) will produce an output of shape [4, 4, 1],</span>
<span class="sd"> unsqueeze(input, -2) will produce an output of shape [4, 1, 4],</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to expand with a singleton dimension.</span>
<span class="sd"> axis : int</span>
<span class="sd"> The index of the singleton dimension in the output tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">axis</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">axis</span> <span class="o">=</span> <span class="n">axis</span> <span class="o">+</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span>
<span class="k">return</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">axis</span><span class="p">)</span></div>
<!-- tensorrt_llm.functional.stack: unsqueezes every input at `dim` then concat()s along
     that new axis. Rendered docstring fixed: "contact" corrected to "concatenate", and a
     stray lone-period line removed; the same fix should land in the Python source. -->
<div class="viewcode-block" id="stack">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.stack">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">stack</span><span class="p">(</span><span class="n">inputs</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to concatenate input tensors along a new dimension.</span>
<span class="sd"> The function creates an operation that creates a new dim for all the</span>
<span class="sd"> input tensors and then concatenates them along that new dim.</span>
<span class="sd"> All the tensors in &#39;inputs&#39; must have the same shape.</span>
<span class="sd"> for ii in range(inputs[0].rank()):</span>
<span class="sd"> assert all(inp.shape[ii] == inputs[0].shape[ii] for inp in inputs)</span>
<span class="sd"> The shape of the output tensor is defined as:</span>
<span class="sd"> output.rank() = inputs[0].rank() + 1</span>
<span class="sd"> output.shape[dim] = len(inputs)</span>
<span class="sd"> for ii in range(inputs[0].rank()):</span>
<span class="sd"> if ii &lt; dim:</span>
<span class="sd"> output.shape[ii] = inputs[0].shape[ii]</span>
<span class="sd"> else:</span>
<span class="sd"> output.shape[ii+1] = inputs[0].shape[ii]</span>
<span class="sd"> For example, given a sequence of two 2D tensors [[0, 1], [2, 3]] and</span>
<span class="sd"> [[4, 5], [6, 7]] both of shape [2, 2],</span>
<span class="sd"> stack(inputs, 0)</span>
<span class="sd"> will produce [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] of shape [2, 2, 2] and</span>
<span class="sd"> stack(inputs, 1)</span>
<span class="sd"> will produce [[[0, 1], [4, 5]], [[2, 3], [6, 7]]] of shape [2, 2, 2].</span>
<span class="sd"> Parameters:</span>
<span class="sd"> inputs : Sequence[Tensor]</span>
<span class="sd"> The sequence of tensors to stack.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension in which the stack is performed.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A tensor that contains the input tensors stacked along a new dimension.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">concat</span><span class="p">([</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">inp</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="n">dim</span><span class="p">)</span> <span class="k">for</span> <span class="n">inp</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">)</span></div>
<!-- tensorrt_llm.functional.expand_dims_like: promotes `left` (int to int32-dims constant,
     float to fp16/fp32 constant matching right.dtype) then left-pads its rank with size-1
     axes via expand_dims until it matches right's rank. Generated Sphinx viewcode page. -->
<div class="viewcode-block" id="expand_dims_like">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand_dims_like">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">expand_dims_like</span><span class="p">(</span><span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span> <span class="n">right</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to expand the first tensor to the same rank as the second</span>
<span class="sd"> tensor.</span>
<span class="sd"> That function takes a first tensor. It also accepts an integer or a float,</span>
<span class="sd"> in which case it creates a constant tensor from it. In both cases, the rank</span>
<span class="sd"> of that first tensor is compared to the rank of the second tensor. If they</span>
<span class="sd"> are of the same rank, the first tensor is returned. Otherwise, the first</span>
<span class="sd"> tensor is expanded on the left to match the rank of the second tensor.</span>
<span class="sd"> Note that the shapes do not have to match, only the rank is considered in</span>
<span class="sd"> that function.</span>
<span class="sd"> For example, for a pair of tensors of shapes [3, 4] and [4, 3, 2], the</span>
<span class="sd"> first tensor will be expanded to a tensor of rank 3 and shape [1, 3, 4].</span>
<span class="sd"> Parameters:</span>
<span class="sd"> left : Union[Tensor, int, float]</span>
<span class="sd"> The first tensor to expand. When a scalar value is provided as a</span>
<span class="sd"> parameter, that function first creates a tensor before expanding it</span>
<span class="sd"> (if needed).</span>
<span class="sd"> right : Tensor</span>
<span class="sd"> The reference tensor to match.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the shuffle layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">left</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="n">left</span><span class="p">]))</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="nb">float</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="ow">and</span> <span class="n">right</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">HALF</span><span class="p">:</span>
<span class="n">left</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp16_array</span><span class="p">([</span><span class="n">left</span><span class="p">]))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">left</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">fp32_array</span><span class="p">([</span><span class="n">left</span><span class="p">]))</span>
<span class="n">left_ndim</span> <span class="o">=</span> <span class="n">left</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="n">right_ndim</span> <span class="o">=</span> <span class="n">right</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">if</span> <span class="n">right_ndim</span> <span class="o">&gt;</span> <span class="n">left_ndim</span><span class="p">:</span>
<span class="n">new_ndim</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">right_ndim</span> <span class="o">-</span> <span class="n">left_ndim</span><span class="p">))</span>
<span class="k">return</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">new_ndim</span><span class="p">)</span>
<span class="k">return</span> <span class="n">left</span></div>
<!-- tensorrt_llm.functional.shape: adds an IShapeLayer; returns the whole shape tensor
     (dim is None) or one scalar extent via gather().view([]). NOTE(review): the rendered
     docstring omits the cast_to_dtype and clip_before_cast parameters; from the code,
     clip_before_cast is a (lower, upper) pair applied via int_clip before an int32 cast,
     and cast_to_dtype casts the shape tensor - confirm and document in the Python source. -->
<span class="c1"># If dim is None, return a 1-D TensorRT LLM tensor of the size</span>
<span class="c1"># If dim is not None, return a 0-D TensorRT LLM tensor of the dimension size</span>
<div class="viewcode-block" id="shape">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.shape">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">cast_to_dtype</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">clip_before_cast</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to create a shape tensor.</span>
<span class="sd"> The shape tensor can either be the shape of the input tensor when the</span>
<span class="sd"> parameter dim is None or a scalar (tensor of rank 0) that corresponds to</span>
<span class="sd"> the size of dim-th dimension.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor from which we want to extract the shape or the</span>
<span class="sd"> size in one dimension.</span>
<span class="sd"> dim : Optional[int]</span>
<span class="sd"> The dimension from which to extract the size. If it is None, the</span>
<span class="sd"> entire shape of the input tensor is returned.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A tensor that contains the shape of the input tensor (if &#39;dim&#39; is None)</span>
<span class="sd"> or the size in the dimension &#39;dim&#39; of the input tensor. If &#39;dim&#39; is</span>
<span class="sd"> &#39;None&#39;, that tensor has the same rank as the input tensor, otherwise</span>
<span class="sd"> its rank is 0.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shape</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">res</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">if</span> <span class="n">cast_to_dtype</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="n">clip_before_cast</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="p">(</span><span class="n">cast_to_dtype</span> <span class="o">==</span> <span class="s1">&#39;int32&#39;</span>
<span class="ow">or</span> <span class="n">cast_to_dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">):</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span>
<span class="n">clip_before_cast</span>
<span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;This parameter only expects a tuple of 2 integers (lower, upper) but got </span><span class="si">{</span><span class="n">clip_before_cast</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">res</span> <span class="o">=</span> <span class="n">int_clip</span><span class="p">(</span><span class="n">res</span><span class="p">,</span> <span class="n">clip_before_cast</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">clip_before_cast</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="n">res</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">res</span><span class="p">,</span> <span class="n">cast_to_dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span> <span class="n">res</span>
<span class="k">return</span> <span class="n">gather</span><span class="p">(</span><span class="n">res</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">indices</span><span class="o">=</span><span class="n">dim</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">([])</span></div>
<!-- tensorrt_llm.functional.gather: ONNX GatherElements via IGatherLayer (ELEMENT mode);
     input and indices must share rank, negative dim resolved against input.ndim().
     Rendered docstring example fixed: the second gather() call was missing a closing
     bracket after [4, 5, 6]; the same fix should land in the Python source. -->
<div class="viewcode-block" id="gather">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gather">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">gather</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">indices</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to gather elements from a tensor.</span>
<span class="sd"> That function implements the GatherElements operator from the ONNX</span>
<span class="sd"> specification as described in</span>
<span class="sd"> https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements</span>
<span class="sd"> The input and indices arguments must have the same rank &gt;= 1. The operation</span>
<span class="sd"> will produce a tensor with the same shape as the indices tensor. The axis</span>
<span class="sd"> is the dimension to gather on.</span>
<span class="sd"> As shown in the ONNX description, for a 3D tensor, the output is:</span>
<span class="sd"> out[i][j][k] = input[indices[i][j][k]][j][k] if axis = 0,</span>
<span class="sd"> out[i][j][k] = input[i][indices[i][j][k]][k] if axis = 1,</span>
<span class="sd"> out[i][j][k] = input[i][j][indices[i][j][k]] if axis = 2.</span>
<span class="sd"> For example,</span>
<span class="sd"> gather([[4, 2], [5, 3]], 0, [[1, 0], [0, 1]])</span>
<span class="sd"> will produce [[5, 2], [4, 3]].</span>
<span class="sd"> gather([[1, 2, 3], [4, 5, 6]], 1, [[1], [0]])</span>
<span class="sd"> will produce [[2], [4]]. See the ONNX documentation for more examples.</span>
<span class="sd"> That operation maps to the TensorRT IGatherLayer.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to gather elements from.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension to gather on.</span>
<span class="sd"> indices : Union[Tensor, int]</span>
<span class="sd"> The positions in the &#39;dim&#39; dimension to gather from.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor containing the gathered elements. It has the same shape as</span>
<span class="sd"> the indices tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">indices</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">indices</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">indices</span><span class="p">]))</span>
<span class="c1"># The input and indices tensors must have the same rank.</span>
<span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="n">indices</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather_v2</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">indices</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">GatherMode</span><span class="o">.</span><span class="n">ELEMENT</span><span class="p">)</span>
<span class="k">if</span> <span class="n">dim</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="n">dim</span>
<span class="n">layer</span><span class="o">.</span><span class="n">axis</span> <span class="o">=</span> <span class="n">dim</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="select"><!-- viewcode block: functional.select; a single-slice add_gather on dim, then a view() that drops that dimension (comment adds no text node, safe inside <pre>) -->
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.select">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">select</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to select a slice of elements from a tensor.</span>
<span class="sd"> Given an input tensor, that function creates an operation that selects the</span>
<span class="sd"> index-th slice of elements in the dimension &#39;dim&#39; to create a new tensor.</span>
<span class="sd"> The output tensor has a shape in which the input dimension &#39;dim&#39; is</span>
<span class="sd"> removed.</span>
<span class="sd"> The &#39;index&#39; can either be an integer or a 1D tensor containing a single</span>
<span class="sd"> element.</span>
<span class="sd"> For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape</span>
<span class="sd"> [3, 3],</span>
<span class="sd"> select(input, 0, 1)</span>
<span class="sd"> will create a tensor of shape [3] that contains the [2, 1, 2].</span>
<span class="sd"> Regarding the shape of the output tensor, the dimension &#39;dim&#39; is removed.</span>
<span class="sd"> It means that for a tensor of shape [4, 2, 6, 3],</span>
<span class="sd"> select(input, 2, 4)</span>
<span class="sd"> will select the 5th slice (index == 4) from the 3rd dimension (dim == 2)</span>
<span class="sd"> and return a tensor of shape [4, 2, 3] (i.e. the 3rd dimension is removed).</span>
<span class="sd"> That operation maps to the TensorRT IGatherLayer.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to select from.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension to select from.</span>
<span class="sd"> index : Union[Tensor, int]</span>
<span class="sd"> The index of the slice in the &#39;dim&#39; dimension to select.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor containing the selected slice.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">index</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">index</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">([</span><span class="n">index</span><span class="p">]))</span>
<span class="k">assert</span> <span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">index</span><span class="o">.</span><span class="n">size</span><span class="p">(</span>
<span class="mi">0</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;index should have rank 1, got </span><span class="si">{</span><span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()):</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="n">dim</span><span class="p">:</span>
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">index</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">new_shape</span><span class="p">))</span></div>
<div class="viewcode-block" id="index_select"><!-- viewcode block: functional.index_select; add_gather with a rank-1 index tensor along dim, then view() to a shape where dim has size len(index). Docstring grammar fixed: "an input tensor" -->
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.index_select">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">index_select</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to select slices of elements from a tensor.</span>
<span class="sd"> Given an input tensor, that function creates an operation that selects the</span>
<span class="sd"> slices of elements in the dimension &#39;dim&#39; at the indices listed in &#39;index&#39;</span>
<span class="sd"> to create a new tensor. The output tensor has the same rank as the input</span>
<span class="sd"> tensor.</span>
<span class="sd"> The &#39;index&#39; is a tensor of rank 1.</span>
<span class="sd"> For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape</span>
<span class="sd"> [3, 3],</span>
<span class="sd"> index_select(input, 0, [0, 1])</span>
<span class="sd"> will create a tensor of shape [2, 3] that contains the [[4, 2, 5], [2, 1, 2]].</span>
<span class="sd"> Regarding the shape of the output tensor, the dimension &#39;dim&#39; has the same</span>
<span class="sd"> size as the &#39;index&#39; tensor. It means that for an input tensor of shape [4, 2, 6, 3],</span>
<span class="sd"> index_select(input, 2, [1, 4])</span>
<span class="sd"> will select the 2nd and 5th slices (index == 1 or 4) from the 3rd dimension</span>
<span class="sd"> (dim == 2) and return a tensor of shape [4, 2, 2, 3] (i.e. the 3rd</span>
<span class="sd"> dimension is shrunk to 2).</span>
<span class="sd"> Note that this operation can also be used to expand a tensor in the &#39;dim&#39;</span>
<span class="sd"> dimension, for example, on input [[0, 1], [2, 3]],</span>
<span class="sd"> index_select(input, 1, [0, 0, 0])</span>
<span class="sd"> will produce a tensor of shape [2, 3] containing [[0, 0, 0], [2, 2, 2]].</span>
<span class="sd"> That operation maps to the TensorRT IGatherLayer.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to select from.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension to select from.</span>
<span class="sd"> index : Tensor</span>
<span class="sd"> The indices of the slices in the &#39;dim&#39; dimension to select.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor containing the selected slices.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;index should have rank 1, got </span><span class="si">{</span><span class="n">index</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()):</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="n">dim</span><span class="p">:</span>
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">new_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="n">index</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">index</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">new_shape</span><span class="p">))</span></div>
<span class="c1"># NOTE: Jointly added with Apple</span>
<div class="viewcode-block" id="scatter"><!-- viewcode block: functional.scatter; add_scatter with ScatterMode.ELEMENT, axis set from dim -->
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.scatter">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">scatter</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">indices</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">updates</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> This operation adds a layer that creates an output tensor by element-wise</span>
<span class="sd"> copying values from the input tensor and then updating values by the given</span>
<span class="sd"> `indices` and `updates` tensors.</span>
<span class="sd"> For a 2D input tensor, it first copies the input to output,</span>
<span class="sd"> then updates the output tensor like the following for each entry in `updates`:</span>
<span class="sd"> output[indices[i][j]][j] = updates[i][j] if dim=0</span>
<span class="sd"> output[i][indices[i][j]] = updates[i][j] if dim=1</span>
<span class="sd"> If the `input` tensor is [[1, 2, 3], [4, 5, 6]],</span>
<span class="sd"> the indices tensor is [[1, 2], [0, 1]],</span>
<span class="sd"> the updates tensor is [[-1, -2], [-3, -4]], and dim=1</span>
<span class="sd"> the output tensor will be [[1, -1, -2], [-3, -4, 6]].</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input: Tensor</span>
<span class="sd"> The input data that needs to be updated.</span>
<span class="sd"> dim: int</span>
<span class="sd"> The axis on which the scatter is to be performed.</span>
<span class="sd"> indices: Tensor</span>
<span class="sd"> An integer tensor of the same rank as input that indicates the positions to be updated.</span>
<span class="sd"> updates: Tensor</span>
<span class="sd"> A data tensor of same shape as the `indices` tensor that contains the update values.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A tensor created by the element-wise scatter layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_scatter</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">indices</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">updates</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ScatterMode</span><span class="o">.</span><span class="n">ELEMENT</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">axis</span> <span class="o">=</span> <span class="n">dim</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="gather_nd"><!-- viewcode block: functional.gather_nd; add_gather_v2 with GatherMode.ND, num_elementwise_dims = batch_dims -->
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gather_nd">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">gather_nd</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">indices</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">batch_dims</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Adds a layer that performs a gather with some element-wise dimensions.</span>
<span class="sd"> See: https://onnx.ai/onnx/operators/onnx__GatherND.html</span>
<span class="sd"> The gather is performed on dim=batch_dims.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input: Tensor</span>
<span class="sd"> The tensor on which the gather operation is performed.</span>
<span class="sd"> indices: Tensor</span>
<span class="sd"> The tensor that indicates which entries to be gathered.</span>
<span class="sd"> batch_dims: int</span>
<span class="sd"> The number of first dimensions that should be skipped before gather starts.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A tensor created by the gather layer with GatherMode.ND.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">gather_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather_v2</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">indices</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">GatherMode</span><span class="o">.</span><span class="n">ND</span><span class="p">)</span>
<span class="n">gather_layer</span><span class="o">.</span><span class="n">num_elementwise_dims</span> <span class="o">=</span> <span class="n">batch_dims</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">gather_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">gather_layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="nonzero"><!-- viewcode block: functional.nonzero; thin wrapper over INetworkDefinition.add_non_zero -->
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.nonzero">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">nonzero</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Adds a layer that finds the indices of non-zero values of the input tensor.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input: Tensor</span>
<span class="sd"> The input tensor for which we need to find the indices of non-zero values.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A tensor of shape [D, C] where D is the number of dimensions of `input` and</span>
<span class="sd"> C is the number of non-zero values in it.</span>
<span class="sd"> Each column of this 2D tensor represents the index tuple for each non-zero value.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">non_zero_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_non_zero</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">non_zero_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">non_zero_layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="masked_select"><!-- viewcode block: functional.masked_select; broadcast + expand mask, NonZero, transpose via Shuffle, then GatherMode.ND gather -->
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.masked_select">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">masked_select</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to select elements from a tensor according to a boolean</span>
<span class="sd"> mask tensor.</span>
<span class="sd"> Given an input tensor, that function creates an operation that selects</span>
<span class="sd"> elements at the indices indicated by the boolean mask tensor to create</span>
<span class="sd"> a new tensor. The output tensor is a 1-D tensor.</span>
<span class="sd"> The input tensor must have rank &gt;= 1. The shapes of the input tensor and</span>
<span class="sd"> the mask tensor dont need to match, but they must be able to be broadcasted.</span>
<span class="sd"> For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape</span>
<span class="sd"> [3, 3],</span>
<span class="sd"> masked_select(input, [[True, False, True], [False, True, False], [True, False, True]])</span>
<span class="sd"> will create a tensor of shape [5] that contains the [4, 5, 1, 4, 1].</span>
<span class="sd"> masked_select(input, [[True], [False], [True]])</span>
<span class="sd"> will create a tensor of shape [6] that contains the [4, 2, 5, 4, 7, 1].</span>
<span class="sd"> masked_select(input, [[False, False, True]])</span>
<span class="sd"> will create a tensor of shape [3] that contains the [5, 2, 1].</span>
<span class="sd"> masked_select(input, [False])</span>
<span class="sd"> will create a tensor of shape [0] which is empty.</span>
<span class="sd"> That operation is implemented by NonZero, Shuffle and GatherV2 layers</span>
<span class="sd"> in TensorRT.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to select from.</span>
<span class="sd"> mask : Tensor</span>
<span class="sd"> The boolean mask tensor that indicates elements to select.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The 1-D tensor containing the selected elements.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">&gt;=</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;input should have rank &gt;= 1&quot;</span>
<span class="nb">input</span><span class="p">,</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span>
<span class="n">expanded_mask</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">))</span>
<span class="n">non_zero_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_non_zero</span><span class="p">(</span><span class="n">expanded_mask</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">shuffle_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shuffle</span><span class="p">(</span><span class="n">non_zero_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
<span class="n">shuffle_layer</span><span class="o">.</span><span class="n">second_transpose</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">gather_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather_v2</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">shuffle_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">GatherMode</span><span class="o">.</span><span class="n">ND</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">gather_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">gather_layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="cumsum">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.cumsum">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">cumsum</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">prefer_plugin</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to calculate inclusive cumulative sum of elements of</span>
<span class="sd"> a tensor in a given dimension.</span>
<span class="sd"> Given an input tensor, that function creates an operation that calculates</span>
<span class="sd"> inclusive cumulative sum of elements in the dimension &#39;dim&#39; to create</span>
<span class="sd"> a new tensor. The output tensor has the same shape as the input tensor.</span>
<span class="sd"> The input tensor must have rank &gt;= 1. The &#39;dim&#39; must be valid, and negative</span>
<span class="sd"> value is supported.</span>
<span class="sd"> For example, on input=[[4, 2, 5], [2, 1, 2], [4, 7, 1]], which has a shape</span>
<span class="sd"> [3, 3],</span>
<span class="sd"> cumsum(input, 0)</span>
<span class="sd"> will produce [[4, 2, 5], [6, 3, 7], [10, 10, 8]].</span>
<span class="sd"> cumsum(input, 1)</span>
<span class="sd"> will produce [[4, 6, 11], [2, 3, 5], [4, 11, 12]].</span>
<span class="sd"> That operation is implemented by TensorRT ILoopLayer.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor to calculate the inclusive cumulative sum.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension to calculate the inclusive cumulative sum. Negative</span>
<span class="sd"> value is supported.</span>
<span class="sd"> prefer_plugin : bool</span>
<span class="sd"> Whether to use the cumsumLastDim plugin if dim is last dim.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor containing the inclusive cumulative sum of input.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">&gt;=</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;input should have rank &gt;= 1&quot;</span>
<span class="k">assert</span> <span class="n">dim</span> <span class="o">&lt;</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="ow">and</span> <span class="n">dim</span> <span class="o">&gt;=</span> <span class="o">-</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">(</span>
<span class="p">),</span> <span class="sa">f</span><span class="s2">&quot;dim should be in [</span><span class="si">{</span><span class="o">-</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">, </span><span class="si">{</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">) when input have rank </span><span class="si">{</span><span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">if</span> <span class="n">dim</span> <span class="o">==</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">if</span> <span class="n">prefer_plugin</span><span class="p">:</span>
<span class="n">last_dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="n">last_dim</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span> <span class="c1"># dynamic?</span>
<span class="n">last_dim</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">old_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">input_2d</span> <span class="o">=</span> <span class="n">unsqueeze</span><span class="p">(</span>
<span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="c1"># special handling of rank-1 dynamic tensor</span>
<span class="k">elif</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">!=</span> <span class="mi">2</span><span class="p">:</span>
<span class="n">input_2d</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">last_dim</span><span class="p">]),</span>
<span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">input_2d</span> <span class="o">=</span> <span class="nb">input</span>
<span class="n">cumsum_last_dim_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">(</span>
<span class="p">)</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span><span class="s1">&#39;CumsumLastDim&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">cumsum_last_dim_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">input_length</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;input_length&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">input_2d</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;type_id&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">input_2d</span><span class="o">.</span><span class="n">dtype</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">input_length</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
<span class="n">cumsum_last_dim_plug</span> <span class="o">=</span> <span class="n">cumsum_last_dim_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span>
<span class="s2">&quot;cumsum_last_dim&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">input_2d</span><span class="p">]</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span>
<span class="n">cumsum_last_dim_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">cumsum_last_dim_plg_creator</span><span class="p">,</span>
<span class="s2">&quot;cumsum_last_dim&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">old_shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># credit to Apple</span>
<span class="n">reduction_length</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">reduction_range</span> <span class="o">=</span> <span class="n">arange</span><span class="p">(</span><span class="n">constant_to_tensor_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;int64&#39;</span><span class="p">,</span>
<span class="n">to_array</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span>
<span class="n">reduction_length</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;int64&#39;</span><span class="p">)</span>
<span class="n">lower_triangle</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">reduction_range</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="o">&lt;=</span> <span class="n">unsqueeze</span><span class="p">(</span><span class="n">reduction_range</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="nb">input</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">unsqueeze</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)</span> <span class="o">*</span> <span class="n">lower_triangle</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">slice_shape</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()):</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="n">dim</span><span class="p">:</span>
<span class="n">slice_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
<span class="n">zero_tensor</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">slice_shape</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">zero_tensor</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">zero_tensor</span><span class="p">,</span>
<span class="p">[</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">slice_shape</span><span class="p">))])</span>
<span class="n">slice_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">slice_shape</span><span class="p">)</span>
<span class="n">zero_tensor</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">zero_tensor</span><span class="p">,</span> <span class="n">slice_shape</span><span class="p">)</span>
<span class="n">loop_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_loop</span><span class="p">()</span>
<span class="n">trip_limit</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span>
<span class="n">loop_layer</span><span class="o">.</span><span class="n">add_trip_limit</span><span class="p">(</span><span class="n">trip_limit</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">TripLimit</span><span class="o">.</span><span class="n">COUNT</span><span class="p">)</span>
<span class="n">iterator_layer</span> <span class="o">=</span> <span class="n">loop_layer</span><span class="o">.</span><span class="n">add_iterator</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
<span class="n">cur_slice</span> <span class="o">=</span> <span class="n">iterator_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">running_sum_layer</span> <span class="o">=</span> <span class="n">loop_layer</span><span class="o">.</span><span class="n">add_recurrence</span><span class="p">(</span><span class="n">zero_tensor</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">running_sum</span> <span class="o">=</span> <span class="n">running_sum_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">cur_sum_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_elementwise</span><span class="p">(</span>
<span class="n">cur_slice</span><span class="p">,</span> <span class="n">running_sum</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">SUM</span><span class="p">)</span>
<span class="n">cur_sum</span> <span class="o">=</span> <span class="n">cur_sum_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">running_sum_layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">cur_sum</span><span class="p">)</span>
<span class="n">loop_output_layer</span> <span class="o">=</span> <span class="n">loop_layer</span><span class="o">.</span><span class="n">add_loop_output</span><span class="p">(</span>
<span class="n">cur_sum</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">LoopOutput</span><span class="o">.</span><span class="n">CONCATENATE</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
<span class="n">loop_output_layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">trip_limit</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">loop_output_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">loop_output_layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="masked_scatter">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.masked_scatter">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">masked_scatter</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">source</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add the masked_scatter base on PyTorch definition.</span>
<span class="sd"> See https://pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter_.html#torch-tensor-masked-scatter for a</span>
<span class="sd"> description of that function.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> mask : Tensor</span>
<span class="sd"> The boolean mask tensor that indicates elements to select.</span>
<span class="sd"> source: Tensor</span>
<span class="sd"> The tensor to copy from</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor containing the source tensor selected by mask.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">&gt;=</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;input should have rank &gt;= 1&quot;</span>
<span class="nb">input</span><span class="p">,</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span>
<span class="n">expanded_mask</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">))</span>
<span class="n">non_zero_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_non_zero</span><span class="p">(</span><span class="n">expanded_mask</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">shuffle_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_shuffle</span><span class="p">(</span><span class="n">non_zero_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
<span class="n">shuffle_layer</span><span class="o">.</span><span class="n">second_transpose</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">source</span> <span class="o">=</span> <span class="n">source</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">scatter_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_scatter</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">shuffle_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">source</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ScatterMode</span><span class="o">.</span><span class="n">ND</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">scatter_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">scatter_layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="concat">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.concat">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">concat</span><span class="p">(</span><span class="n">inputs</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">]],</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to concatenate tensors.</span>
<span class="sd"> The function creates an operation that concatenates the tensors from the</span>
<span class="sd"> sequence &#39;inputs&#39;. The concatenation is done along the dimension &#39;dim&#39;.</span>
<span class="sd"> All the tensors in &#39;inputs&#39; must have the same shape expect for the</span>
<span class="sd"> dimension &#39;dim&#39;.</span>
<span class="sd"> for ii in range(inputs[0].rank()):</span>
<span class="sd"> assert (ii == dim) or all(inp.shape[ii] == inputs[0].shape[ii] for inp in inputs)</span>
<span class="sd"> The shape of the output tensor is defined as:</span>
<span class="sd"> for ii in range(inputs[0].rank()):</span>
<span class="sd"> # Same size as all the inputs in dimension ii != dim.</span>
<span class="sd"> output.shape[ii] = inputs[0].shape[ii]</span>
<span class="sd"> # Sum of the sizes in the different inputs in dimension &#39;dim&#39;.</span>
<span class="sd"> if ii == dim:</span>
<span class="sd"> for jj in range(1, len(inputs)):</span>
<span class="sd"> output.shape[ii] += inputs[jj].shape[ii]</span>
<span class="sd"> For example, given a sequence of two 2D tensors [[0, 1], [2, 3]] and</span>
<span class="sd"> [[4, 5], [6, 7]] both of shape [2, 2],</span>
<span class="sd"> concat(inputs, 0)</span>
<span class="sd"> will produce [[0, 1], [2, 3], [4, 5], [6, 7]] of shape [4, 2] and</span>
<span class="sd"> concat(inputs, 1)</span>
<span class="sd"> will produce [[0, 1, 4, 5], [2, 3, 6, 7]] of shape [2, 4].</span>
<span class="sd"> Parameters:</span>
<span class="sd"> inputs : Sequence[Union[Tensor, int]]</span>
<span class="sd"> The sequence of tensors to concatenate. For integers, that function</span>
<span class="sd"> creates constant tensors.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension in which the concatenation is performed.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A tensor that contains the concatenation of the tensors.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span>
<span class="n">inputs</span>
<span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;Number of inputs (</span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span><span class="si">}</span><span class="s2">) to the concatenation layer must be &gt; 0.&quot;</span>
<span class="n">tmp</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="n">constants_to_tensors_</span><span class="p">(</span><span class="o">*</span><span class="n">inputs</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
<span class="k">if</span> <span class="n">i</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">tmp</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">i</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="mi">1</span><span class="p">]))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">tmp</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_concatenation</span><span class="p">([</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">tmp</span><span class="p">])</span>
<span class="n">layer</span><span class="o">.</span><span class="n">axis</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">tmp</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">ndim</span><span class="p">())[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="softmax">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.softmax">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">softmax</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to compute softmax on a tensor.</span>
<span class="sd"> That operation computes the softmax on the input tensor in the dimension</span>
<span class="sd"> &#39;dim&#39; if specified. Otherwise, it is applied on the last dimension.</span>
<span class="sd"> It inserts a ISoftmaxLayer to the TensorRT graph.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor on which to apply softmax.</span>
<span class="sd"> dim : Optional[int]</span>
<span class="sd"> The dimension used to apply softmax.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The output tensor of the softmax layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span>
<span class="k">if</span> <span class="n">dim</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="n">dim</span>
<span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_softmax</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<span class="k">def</span><span class="w"> </span><span class="nf">_lookup_plugin</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">rank</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">per_token_scale</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to perform lookup in a tensor.</span>
<span class="sd"> That operation performs the lookup needed by embedding layers. Given a</span>
<span class="sd"> &#39;weight&#39; tensor of shape [rows, cols], it produces a tensor of shape</span>
<span class="sd"> [inputs.size(0), cols] where the ith row corresponds to the input[i] row in</span>
<span class="sd"> the weight tensor.</span>
<span class="sd"> It inserts a IPluginV2Layer.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor contains the indices to perform the lookup.</span>
<span class="sd"> weight : Tensor</span>
<span class="sd"> The table to gather from.</span>
<span class="sd"> rank : int</span>
<span class="sd"> The mpi rank.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The output tensor of the lookup layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;Lookup&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">per_token_scale</span><span class="o">.</span><span class="n">dtype</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;rank&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">rank</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">pf_type</span><span class="p">,</span> <span class="n">rank</span><span class="p">])</span>
<span class="n">lookup_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;lookup&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="k">if</span> <span class="n">per_token_scale</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">per_token_scale</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">lookup_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creator</span><span class="p">,</span> <span class="s2">&quot;lookup&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<div class="viewcode-block" id="embedding">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.embedding">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">embedding</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">tp_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">tp_group</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">sharding_dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">tp_rank</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">per_token_scale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">padding</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to perform embedding lookup.</span>
<span class="sd"> That operation performs the embedding lookup. The &#39;input&#39; tensor contains</span>
<span class="sd"> the identifiers of the rows of &#39;weight&#39; to gather.</span>
<span class="sd"> 1. Distribute the embedding lookup table over multiple GPU</span>
<span class="sd"> When &#39;tp_size&#39; is greater than 1 and the &#39;tp_group&#39; is defined, this</span>
<span class="sd"> embedding lookup is distributed among multiple GPUs.</span>
<span class="sd"> When &#39;sharding_dim==0&#39;, each GPU stores a subset of the rows of the embedding</span>
<span class="sd"> table (that number of rows per GPU is given by weights.shape[0] and the offset to</span>
<span class="sd"> the 1st row stored on the GPU is given by rank * weights.shape[0]). Each</span>
<span class="sd"> parallel rank will query all the indices and set 0s for the weights that</span>
<span class="sd"> are not stored on the associated GPU. To compute the final result, a</span>
<span class="sd"> parallel all-reduce operation is added to the TensorRT graph. That lookup</span>
<span class="sd"> can be performed using either the plugin or the operators TensorRT support.</span>
<span class="sd"> When &#39;sharding_dim==1&#39;, each GPU stores a subset of the embedding table&#39;s columns.</span>
<span class="sd"> Each rank can obtain a portion of the embedding results.</span>
<span class="sd"> Then the embedding is collected using the all-gather operation.</span>
<span class="sd"> Related transposition operations are also used to obtain the final results.</span>
<span class="sd"> 2. Store embedding lookup table as a whole</span>
<span class="sd"> When &#39;tp_size&#39; is not greater than 1, the embedding lookup table will not</span>
<span class="sd"> be divided. In this case, when the default_net().plugin_config.lookup_plugin is set,</span>
<span class="sd"> the operation is implemented using a plugin (without the all-reduce operation).</span>
<span class="sd"> Otherwise, this operation is implemented using the standard IGatherLayer in TensorRT.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor that contains the indices to perform the lookup.</span>
<span class="sd"> weight : Tensor</span>
<span class="sd"> The table to gather from.</span>
<span class="sd"> tp_size : int</span>
<span class="sd"> The number of GPUs collaborating to perform that embedding.</span>
<span class="sd"> tp_group : Optional[List[int]]</span>
<span class="sd"> The group of world ranks participating in the all-reduce when</span>
<span class="sd"> tp_size &gt; 1.</span>
<span class="sd"> sharding_dim : int</span>
<span class="sd"> sharding_dim = 0 means that we shard the embedding table in vocab dim;</span>
<span class="sd"> sharding_dim = 1 means that we shard the embedding table in embedding dim.</span>
<span class="sd"> tp_rank : int</span>
<span class="sd"> The tensor parallelism rank. Used to calculate offset in TP on vocab dim.</span>
<span class="sd"> per_token_scale : Optional[Tensor]</span>
<span class="sd"> Per-token scale factors for the lookup. When not None, the lookup</span>
<span class="sd"> plugin is used, as it is the only implementation that supports them.</span>
<span class="sd"> padding: Tensor</span>
<span class="sd"> Additional padding added to the end of the embedding table before feeding into gather op.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the embedding lookup layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="c1"># Per token scale is only supported by lookup plugin so if per_token_scale is not None, we must use lookup plugin</span>
<span class="c1"># Otherwise, we prefer to use ootb</span>
<span class="n">use_lookup_plugin</span> <span class="o">=</span> <span class="n">per_token_scale</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">padding</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">padded_weight</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">weight</span><span class="p">,</span> <span class="n">padding</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">padded_weight</span> <span class="o">=</span> <span class="n">weight</span>
<span class="c1"># Distribute embedding lookup table across multiple GPU</span>
<span class="k">if</span> <span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">tp_group</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="n">sharding_dim</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="c1"># TP on vocab_size dimension</span>
<span class="k">if</span> <span class="n">tp_rank</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s2">&quot;Rank cannot be none for tensor parallelism on vocab dim&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_lookup_plugin</span><span class="p">:</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">_lookup_plugin</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">tp_rank</span><span class="p">,</span> <span class="n">per_token_scale</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">allreduce</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">tp_group</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">shape_weight</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">weight</span><span class="p">)</span>
<span class="n">vocab_size</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">shape_weight</span><span class="p">,</span> <span class="n">starts</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">sizes</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="n">tmp_input</span> <span class="o">=</span> <span class="nb">input</span> <span class="o">-</span> <span class="n">vocab_size</span> <span class="o">*</span> <span class="n">tp_rank</span>
<span class="c1"># Identify the valid indices</span>
<span class="n">is_qualified</span> <span class="o">=</span> <span class="n">op_and</span><span class="p">(</span><span class="n">tmp_input</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">tmp_input</span> <span class="o">&lt;</span> <span class="n">vocab_size</span><span class="p">)</span>
<span class="n">is_qualified_expand</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">is_qualified</span><span class="p">,</span>
<span class="p">[</span><span class="n">is_qualified</span><span class="o">.</span><span class="n">ndim</span><span class="p">()])</span>
<span class="c1"># Replace the invalid ones to zero</span>
<span class="n">placeholder_input</span> <span class="o">=</span> <span class="n">where</span><span class="p">(</span><span class="n">is_qualified</span><span class="p">,</span> <span class="n">tmp_input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="c1"># Get the temporal results</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span>
<span class="n">padded_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">placeholder_input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">tmp_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="c1"># Set zero for invalid results</span>
<span class="n">placeholder_tmp</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">is_qualified_expand</span><span class="p">,</span> <span class="n">tmp_output</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">placeholder</span> <span class="o">=</span> <span class="n">placeholder_tmp</span> <span class="o">-</span> <span class="n">placeholder_tmp</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">where</span><span class="p">(</span><span class="n">is_qualified_expand</span><span class="p">,</span> <span class="n">tmp_output</span><span class="p">,</span> <span class="n">placeholder</span><span class="p">)</span>
<span class="c1"># Use all reduce to collect the results</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">allreduce</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">tp_group</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">sharding_dim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span> <span class="c1"># TP on hidden dimension</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="n">padded_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="c1"># [dim0, local_dim] -&gt; [dim0 * tp_size, local_dim] --&gt; [dim0, local_dim * tp_size]</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">allgather</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">tp_group</span><span class="p">,</span> <span class="n">gather_dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s1">&#39;Tensor Parallelism only support splitting Embedding lookup along hidden (sharding_dim==1) and vocab (sharding_dim==0) dimensionis&#39;</span>
<span class="p">)</span><!-- NOTE(review): "dimensionis" typo lives in the runtime error string; fix belongs in functional.py, not this generated page -->
<span class="c1"># Store embedding lookup table as a whole</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="n">use_lookup_plugin</span><span class="p">:</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">_lookup_plugin</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span>
<span class="n">padded_weight</span><span class="p">,</span>
<span class="n">rank</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">per_token_scale</span><span class="o">=</span><span class="n">per_token_scale</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_gather</span><span class="p">(</span><span class="n">padded_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">return</span> <span class="n">x</span></div>
<div class="viewcode-block" id="constant_to_tensor_">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.constant_to_tensor_">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">constant_to_tensor_</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="nb">bool</span><span class="p">],</span>
<span class="n">dtype</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">,</span> <span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">to_array</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Create a constant Tensor from a Python scalar, or pass a Tensor through.</span>
<span class="sd"> If &#39;input&#39; is already a Tensor, it is returned unchanged. Otherwise a</span>
<span class="sd"> constant is created. When &#39;dtype&#39; is None the type is deduced from the</span>
<span class="sd"> value (bool -&gt; trt.bool, int -&gt; trt.int32, otherwise trt.float32); a</span>
<span class="sd"> string &#39;dtype&#39; is converted with str_dtype_to_trt. When &#39;to_array&#39; is</span>
<span class="sd"> True, the scalar is wrapped in a one-element array before conversion.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">dtype</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># deduce the type from the given value</span>
<span class="c1"># NOTE: bool is a subtype of int, so bool needs to be checked first</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="nb">bool</span><span class="p">):</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">bool</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">float32</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">array_fn_dict</span> <span class="o">=</span> <span class="p">{</span>
<span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">:</span> <span class="n">int64_array</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span> <span class="n">int32_array</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">float32</span><span class="p">:</span> <span class="n">fp32_array</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">float16</span><span class="p">:</span> <span class="n">fp16_array</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">:</span> <span class="n">bf16_array</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">bool</span><span class="p">:</span> <span class="n">bool_array</span><span class="p">,</span>
<span class="p">}</span>
<span class="k">assert</span> <span class="n">dtype</span> <span class="ow">in</span> <span class="n">array_fn_dict</span>
<span class="k">return</span> <span class="n">constant</span><span class="p">(</span><span class="n">array_fn_dict</span><span class="p">[</span><span class="n">dtype</span><span class="p">]([</span><span class="nb">input</span><span class="p">]</span> <span class="k">if</span> <span class="n">to_array</span> <span class="k">else</span> <span class="nb">input</span><span class="p">))</span>
<span class="k">return</span> <span class="nb">input</span></div>
<div class="viewcode-block" id="constants_to_tensors_">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.constants_to_tensors_">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">constants_to_tensors_</span><span class="p">(</span>
<span class="o">*</span><span class="n">inputs</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="o">...</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Helper function to create tensors from multiple inputs.</span>
<span class="sd"> For each input, the function first creates a constant tensor if the input</span>
<span class="sd"> is an integer or a float. Then, if any input is int64, it upcasts other</span>
<span class="sd"> integer inputs to int64.</span>
<span class="sd"> (A plain int counts as int64 when it falls outside the 32-bit range.)</span>
<span class="sd"> Parameters:</span>
<span class="sd"> inputs : Tuple[Union[Tensor, int, float], ...]</span>
<span class="sd"> The inputs to create tensors from.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A tuple of tensors.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">has_int64</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="ow">and</span> <span class="p">(</span><span class="n">i</span> <span class="o">&gt;=</span> <span class="mi">2</span><span class="o">**</span><span class="mi">31</span> <span class="ow">or</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="o">-</span><span class="mi">2</span><span class="o">**</span><span class="mi">31</span><span class="p">)</span>\
<span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="ow">and</span> <span class="n">i</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">:</span>
<span class="n">has_int64</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">break</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">has_int64</span><span class="p">:</span>
<span class="k">return</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">)</span>
<span class="n">result</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="ow">and</span> <span class="n">i</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span>
<span class="n">result</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span> <span class="k">if</span> <span class="n">has_int64</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">result</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">i</span><span class="p">))</span>
<span class="k">return</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">result</span><span class="p">)</span></div>
<div class="viewcode-block" id="broadcast_helper">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.broadcast_helper">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">broadcast_helper</span><span class="p">(</span><span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
<span class="n">right</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Helper function to perform a broadcast.</span>
<span class="sd"> For each input, that function first creates a constant tensor if the input</span>
<span class="sd"> is an integer or a float. Then, if needed, it expands the smaller tensor to</span>
<span class="sd"> make sure its rank is the same as the larger one.</span>
<span class="sd"> In strongly typed networks, a scalar operand inherits the dtype of the</span>
<span class="sd"> Tensor operand so both sides agree.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> left : Union[Tensor, int, float]</span>
<span class="sd"> The first input. If that input is an integer or a float, the</span>
<span class="sd"> function creates a constant tensor.</span>
<span class="sd"> right : Union[Tensor, int, float]</span>
<span class="sd"> The second input. If that input is an integer or a float, the</span>
<span class="sd"> function creates a constant tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A pair of tensors of same rank.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
<span class="n">left</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">left</span><span class="p">)</span>
<span class="n">right</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">right</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">left</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span>
<span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">right</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">left</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">==</span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
<span class="k">return</span> <span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">&lt;</span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
<span class="n">left</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">&gt;</span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
<span class="n">right</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">left</span><span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span></div>
<div class="viewcode-block" id="elementwise_binary">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.elementwise_binary">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">elementwise_binary</span><span class="p">(</span><span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span>
<span class="nb">float</span><span class="p">],</span> <span class="n">right</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
<span class="n">op</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an elementwise operation with two inputs.</span>
<span class="sd"> For each input, that function first creates a constant tensor if the input</span>
<span class="sd"> is an integer or a float. Then, if needed, it expands the smaller tensor to</span>
<span class="sd"> make sure its rank is the same as the larger one. Then, it performs the</span>
<span class="sd"> elementwise operation &#39;op&#39;.</span>
<span class="sd"> The following closures are defined in functional.*:</span>
<span class="sd"> add for op=trt.ElementWiseOperation.SUM</span>
<span class="sd"> sub for op=trt.ElementWiseOperation.SUB</span>
<span class="sd"> mul for op=trt.ElementWiseOperation.PROD</span>
<span class="sd"> div for op=trt.ElementWiseOperation.DIV</span>
<span class="sd"> floordiv for op=trt.ElementWiseOperation.FLOOR_DIV</span>
<span class="sd"> gt for op=trt.ElementWiseOperation.GREATER</span>
<span class="sd"> lt for op=trt.ElementWiseOperation.LESS</span>
<span class="sd"> op_and for op=trt.ElementWiseOperation.AND</span>
<span class="sd"> op_or for op=trt.ElementWiseOperation.OR</span>
<span class="sd"> eq for op=trt.ElementWiseOperation.EQUAL</span>
<span class="sd"> minimum for op=trt.ElementWiseOperation.MIN</span>
<span class="sd"> maximum for op=trt.ElementWiseOperation.MAX</span>
<span class="sd"> pow for op=trt.ElementWiseOperation.POW</span>
<span class="sd"> It is implemented using the IElementWiseLayer from TensorRT.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> left : Union[Tensor, int, float]</span>
<span class="sd"> The first input. If that input is an integer or a float, the</span>
<span class="sd"> function creates a constant tensor.</span>
<span class="sd"> right : Union[Tensor, int, float]</span>
<span class="sd"> The second input. If that input is an integer or a float, the</span>
<span class="sd"> function creates a constant tensor.</span>
<span class="sd"> op : trt.ElementWiseOperation</span>
<span class="sd"> The binary operation to perform.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by this elementwise operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">left</span><span class="p">,</span> <span class="n">right</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span> <span class="ow">and</span> <span class="n">right</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">:</span>
<span class="n">left</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="k">if</span> <span class="n">left</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span> <span class="ow">and</span> <span class="n">right</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span>
<span class="n">right</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">int64</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_elementwise</span><span class="p">(</span><span class="n">left</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">right</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">op</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<span class="n">add</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">SUM</span><span class="p">)</span>
<span class="n">sub</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">SUB</span><span class="p">)</span>
<span class="n">mul</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">PROD</span><span class="p">)</span>
<span class="n">div</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">DIV</span><span class="p">)</span>
<span class="n">floordiv</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">FLOOR_DIV</span><span class="p">)</span>
<span class="n">gt</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">GREATER</span><span class="p">)</span>
<span class="n">lt</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">LESS</span><span class="p">)</span>
<span class="n">op_and</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">AND</span><span class="p">)</span>
<span class="n">op_or</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">OR</span><span class="p">)</span>
<span class="n">eq</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">EQUAL</span><span class="p">)</span>
<span class="n">minimum</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">MIN</span><span class="p">)</span>
<span class="n">maximum</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">MAX</span><span class="p">)</span>
<span class="nb">pow</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">POW</span><span class="p">)</span>
<span class="n">op_xor</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">elementwise_binary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ElementWiseOperation</span><span class="o">.</span><span class="n">XOR</span><span class="p">)</span>
<div class="viewcode-block" id="modulo">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.modulo">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">modulo</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> This function adds an element-wise modulo (x % y) operation for a given tensor.</span>
<span class="sd"> Since there is no TensorRT layer that can directly perform this,</span>
<span class="sd"> this function implements it using some of the basic operations.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A tensor that represents (x % y) modulo operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">x</span> <span class="o">-</span> <span class="p">(</span><span class="n">x</span> <span class="o">//</span> <span class="n">y</span><span class="p">)</span> <span class="o">*</span> <span class="n">y</span></div>
<div class="viewcode-block" id="where">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.where">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">where</span><span class="p">(</span><span class="n">condition</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">bool</span><span class="p">],</span> <span class="n">left</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span>
<span class="n">right</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a where (aka select or if-then-else) operation.</span>
<span class="sd"> Assuming the three input parameters have the same shape, that function creates</span>
<span class="sd"> the operation to compute a tensor of the same shape such that:</span>
<span class="sd"> for ii in range(mul(condition.shape)):</span>
<span class="sd"> output[ii] = left[ii] if condition[ii] else right[ii]</span>
<span class="sd"> For each input, that function first creates a constant tensor if the</span>
<span class="sd"> condition is boolean or the left/right input is an integer or a float.</span>
<span class="sd"> Then, if needed, it expands the smaller tensor to make sure its</span>
<span class="sd"> rank is the same as the larger one. Then, it performs the selection.</span>
<span class="sd"> It is implemented using the ISelectLayer from TensorRT.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> condition : Union[Tensor, bool]</span>
<span class="sd"> The condition. If that input is a boolean, the function</span>
<span class="sd"> creates a constant tensor.</span>
<span class="sd"> left : Union[Tensor, int, float]</span>
<span class="sd"> The first input. If that input is an integer or a float, the</span>
<span class="sd"> function creates a constant tensor.</span>
<span class="sd"> right : Union[Tensor, int, float]</span>
<span class="sd"> The second input. If that input is an integer or a float, the</span>
<span class="sd"> function creates a constant tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by this where operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="c1"># Convert to tensors.</span>
<span class="n">condition</span> <span class="o">=</span> <span class="n">constant_to_tensor_</span><span class="p">(</span><span class="n">condition</span><span class="p">)</span>
<span class="n">left</span><span class="p">,</span> <span class="n">right</span> <span class="o">=</span> <span class="n">constants_to_tensors_</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">right</span><span class="p">)</span>
<span class="c1"># Find the tensor with the largest rank of the three.</span>
<span class="n">largest</span> <span class="o">=</span> <span class="n">condition</span>
<span class="k">if</span> <span class="n">largest</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">&lt;</span> <span class="n">left</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
<span class="n">largest</span> <span class="o">=</span> <span class="n">left</span>
<span class="k">if</span> <span class="n">largest</span><span class="o">.</span><span class="n">rank</span><span class="p">()</span> <span class="o">&lt;</span> <span class="n">right</span><span class="o">.</span><span class="n">rank</span><span class="p">():</span>
<span class="n">largest</span> <span class="o">=</span> <span class="n">right</span>
<span class="c1"># Expand the tensors to match the largest one.</span>
<span class="k">if</span> <span class="n">condition</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">largest</span><span class="p">:</span>
<span class="n">condition</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">condition</span><span class="p">,</span> <span class="n">largest</span><span class="p">)</span>
<span class="k">if</span> <span class="n">left</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">largest</span><span class="p">:</span>
<span class="n">left</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">left</span><span class="p">,</span> <span class="n">largest</span><span class="p">)</span>
<span class="k">if</span> <span class="n">right</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">largest</span><span class="p">:</span>
<span class="n">right</span> <span class="o">=</span> <span class="n">expand_dims_like</span><span class="p">(</span><span class="n">right</span><span class="p">,</span> <span class="n">largest</span><span class="p">)</span>
<span class="c1"># Insert the operation.</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_select</span><span class="p">(</span><span class="n">condition</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">left</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">right</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="unary">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.unary">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">unary</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">op</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an elementwise operation on a single input.</span>
<span class="sd"> The following closures are defined in functional.*:</span>
<span class="sd"> round for op=trt.UnaryOperation.ROUND</span>
<span class="sd"> sqrt for op=trt.UnaryOperation.SQRT</span>
<span class="sd"> exp for op=trt.UnaryOperation.EXP</span>
<span class="sd"> sin for op=trt.UnaryOperation.SIN</span>
<span class="sd"> cos for op=trt.UnaryOperation.COS</span>
<span class="sd"> abs for op=trt.UnaryOperation.ABS</span>
<span class="sd"> log for op=trt.UnaryOperation.LOG</span>
<span class="sd"> It is implemented using the IUnaryLayer from TensorRT.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> op : trt.UnaryOperation</span>
<span class="sd"> The unary operation to perform.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by this elementwise operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_unary</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">op</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<span class="nb">round</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">ROUND</span><span class="p">)</span>
<span class="n">sqrt</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">SQRT</span><span class="p">)</span>
<span class="n">exp</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">EXP</span><span class="p">)</span>
<span class="n">sin</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">SIN</span><span class="p">)</span>
<span class="n">cos</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">COS</span><span class="p">)</span>
<span class="nb">abs</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">ABS</span><span class="p">)</span>
<span class="n">log</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">LOG</span><span class="p">)</span>
<span class="n">not_op</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">unary</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">UnaryOperation</span><span class="o">.</span><span class="n">NOT</span><span class="p">)</span>
<div class="viewcode-block" id="log_softmax">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.log_softmax">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">log_softmax</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> This function is equivalent to torch.nn.functional.log_softmax(), i.e.</span>
<span class="sd"> it performs log(softmax(input)) in a safer and faster way.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input: Tensor</span>
<span class="sd"> The data tensor on which log_softmax to be computed.</span>
<span class="sd"> dim: int</span>
<span class="sd"> The dimension of the input tensor along which log_softmax will be computed.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A tensor of same shape as input with log_softmax computed on the specified dim.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">x_max</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="nb">input</span> <span class="o">-</span> <span class="n">x_max</span>
<span class="k">return</span> <span class="n">x</span> <span class="o">-</span> <span class="n">log</span><span class="p">(</span><span class="nb">sum</span><span class="p">(</span><span class="n">exp</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span></div>
<div class="viewcode-block" id="reduce">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.reduce">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">reduce</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">op</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a reduction operation along a dimension.</span>
<span class="sd"> It is implemented using the IReduceLayer from TensorRT.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> op : trt.ReduceOperation</span>
<span class="sd"> The reduction operation to perform.</span>
<span class="sd"> Options: SUM, PROD, MAX, MIN, AVG</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension along which the reduction is performed.</span>
<span class="sd"> keepdim : bool</span>
<span class="sd"> Is the dimension kept in the reduced tensor? When True the</span>
<span class="sd"> dimension is kept, it is removed from the shape otherwise.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by this reduction operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
<span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_reduce</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">op</span><span class="p">,</span>
<span class="n">axes</span><span class="p">,</span>
<span class="n">keep_dims</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<span class="n">prod</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">reduce</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">PROD</span><span class="p">)</span>
<span class="nb">min</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">reduce</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">MIN</span><span class="p">)</span>
<div class="viewcode-block" id="mean">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.mean">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">mean</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to compute the mean along a dimension.</span>
<span class="sd"> Computes the mean along the dimension &#39;dim&#39; of the input tensor.</span>
<span class="sd"> It is implemented using the IReduceLayer from TensorRT.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension along which the mean is computed.</span>
<span class="sd"> keepdim : bool</span>
<span class="sd"> Is the dimension kept in the reduced tensor? When True the</span>
<span class="sd"> dimension is kept, it is removed from the shape otherwise.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by this reduction operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">AVG</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span></div>
<div class="viewcode-block" id="max">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.max">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">max</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to compute the max along a dimension.</span>
<span class="sd"> Computes the max along the dimension &#39;dim&#39; of the input tensor.</span>
<span class="sd"> It is implemented using the IReduceLayer from TensorRT.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension along which the max is computed.</span>
<span class="sd"> keepdim : bool</span>
<span class="sd"> Is the dimension kept in the reduced tensor? When True the</span>
<span class="sd"> dimension is kept; otherwise it is removed from the shape.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by this reduction operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">MAX</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span></div>
<div class="viewcode-block" id="sum">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.sum">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">sum</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to compute the sum along a dimension.</span>
<span class="sd"> Computes the sum along the dimension &#39;dim&#39; of the input tensor.</span>
<span class="sd"> It is implemented using the IReduceLayer from TensorRT.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension along which the sum is computed.</span>
<span class="sd"> keepdim : bool</span>
<span class="sd"> Is the dimension kept in the reduced tensor? When True the</span>
<span class="sd"> dimension is kept; otherwise it is removed from the shape.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by this reduction operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ReduceOperation</span><span class="o">.</span><span class="n">SUM</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="n">keepdim</span><span class="p">)</span></div>
<div class="viewcode-block" id="identity">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.identity">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">identity</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an identity operation.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by this identity operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">identity_plugin</span><span class="p">:</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_identity</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;Identity&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">()</span>
<span class="n">id_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;identity&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">id_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creator</span><span class="p">,</span> <span class="s2">&quot;identity&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="argmax">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.argmax">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">argmax</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">keepdim</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an argmax operation.</span>
<span class="sd"> As explained in the ONNX documentation,</span>
<span class="sd"> https://github.com/onnx/onnx/blob/main/docs/Operators.md#argmax</span>
<span class="sd"> that function creates a layer computing the indices of the max elements of</span>
<span class="sd"> the input tensor&#39;s element along the provided dim. The resulting tensor</span>
<span class="sd"> has the same rank as the input if keepdim is True. If keepdim is False,</span>
<span class="sd"> then the resulting tensor has the reduced dimension pruned.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension in which to compute the argmax indices.</span>
<span class="sd"> keepdim : bool</span>
<span class="sd"> Do we keep the dimension along which the reduction is performed?</span>
<span class="sd"> Yes, if set to True, no otherwise.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by this argmax operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
<span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_topk</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">TopKOperation</span><span class="o">.</span><span class="n">MAX</span><span class="p">,</span>
<span class="mi">1</span><span class="p">,</span> <span class="n">axes</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="n">keepdim</span><span class="p">:</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">layer</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">layer</span><span class="p">)</span>
<span class="n">a</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()))</span>
<span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="n">dim</span><span class="p">:</span>
<span class="n">a</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="n">d</span><span class="p">)</span>
<span class="n">indices</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="n">a</span><span class="p">))</span>
<span class="n">output_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">output</span><span class="p">)</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span><span class="n">output_shape</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">indices</span><span class="p">)</span>
<span class="k">return</span> <span class="n">view</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">new_shape</span><span class="p">)</span></div>
<div class="viewcode-block" id="gelu">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gelu">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">gelu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a GELU operation.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> x : Tensor</span>
<span class="sd"> The input tensor on which the activation function is applied.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the activation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">return</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">x</span> <span class="o">*</span> <span class="p">(</span>
<span class="n">tanh</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">2.0</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="mf">0.044715</span> <span class="o">*</span> <span class="nb">pow</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mf">3.0</span><span class="p">)))</span> <span class="o">+</span> <span class="mf">1.0</span><span class="p">)</span></div>
<div class="viewcode-block" id="geglu">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.geglu">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">geglu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a Gated-GELU operation.</span>
<span class="sd"> That function takes a tensor, splits it into two halves along the last</span>
<span class="sd"> dimension, applies GELU to the second half and multiplies the results. The</span>
<span class="sd"> behavior is undefined if the last dimension is not even.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> x : Tensor</span>
<span class="sd"> The input tensor on which the activation function is applied.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the activation layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="n">chunk</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">a</span> <span class="o">*</span> <span class="n">gelu</span><span class="p">(</span><span class="n">b</span><span class="p">)</span></div>
<div class="viewcode-block" id="quick_gelu">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.quick_gelu">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">quick_gelu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">sigmoid</span><span class="p">(</span><span class="mf">1.702</span> <span class="o">*</span> <span class="n">x</span><span class="p">)</span></div>
<div class="viewcode-block" id="gegelu">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gegelu">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">gegelu</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">limit</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="c1"># a, b = x[..., ::2], x[..., 1::2]</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="n">a_starts</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
<span class="n">b_starts</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
<span class="n">shapes</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
<span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)</span>
<span class="p">])</span>
<span class="n">strides</span> <span class="o">=</span> <span class="p">[</span><span class="mi">2</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
<span class="n">a</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">a_starts</span><span class="p">,</span> <span class="n">shapes</span><span class="p">,</span> <span class="n">strides</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">b_starts</span><span class="p">,</span> <span class="n">shapes</span><span class="p">,</span> <span class="n">strides</span><span class="p">)</span>
<span class="k">if</span> <span class="n">limit</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">clip</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="nb">float</span><span class="p">(</span><span class="o">-</span><span class="mf">1e20</span><span class="p">),</span> <span class="n">beta</span><span class="o">=</span><span class="n">limit</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">clip</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=-</span><span class="n">limit</span><span class="p">,</span> <span class="n">beta</span><span class="o">=</span><span class="n">limit</span><span class="p">)</span>
<span class="c1"># C = B + 1</span>
<span class="n">const1</span> <span class="o">=</span> <span class="n">arange</span><span class="p">(</span><span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="mi">1</span><span class="p">)),</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="mi">2</span><span class="p">)),</span>
<span class="n">trt_dtype_to_str</span><span class="p">(</span><span class="n">b</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
<span class="n">const1</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">const1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">b_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)])</span>
<span class="n">const1_arr</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">const1</span><span class="p">,</span> <span class="n">b_shape</span><span class="p">)</span>
<span class="k">return</span> <span class="n">quick_gelu</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">b</span> <span class="o">+</span> <span class="n">const1_arr</span><span class="p">)</span></div>
<div class="viewcode-block" id="group_norm">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.group_norm">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">group_norm</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">num_groups</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">):</span>
<span class="c1">##</span>
<span class="c1">## TODO: Document that function!</span>
<span class="c1">##</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
<span class="n">num_channels</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="n">old_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
<span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">num_groups</span><span class="p">,</span>
<span class="n">num_channels</span> <span class="o">//</span> <span class="n">num_groups</span><span class="p">,</span>
<span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">ndim</span><span class="p">)])</span>
<span class="n">x</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
<span class="c1"># instance norm</span>
<span class="n">w_shape</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_groups</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)]</span>
<span class="n">instance_weight</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)))</span>
<span class="n">instance_bias</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">w_shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)))</span>
<span class="n">axes_mask</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()):</span>
<span class="n">axes_mask</span> <span class="o">|=</span> <span class="mi">1</span> <span class="o">&lt;&lt;</span> <span class="n">i</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_normalization</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">instance_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">instance_bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">axes_mask</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">epsilon</span> <span class="o">=</span> <span class="n">eps</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">old_shape</span><span class="p">)</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">num_channels</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">ndim</span><span class="p">)])</span>
<span class="k">if</span> <span class="n">weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">*</span> <span class="n">weight</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">+</span> <span class="n">bias</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
<span class="k">return</span> <span class="n">y</span></div>
<div class="viewcode-block" id="softplus">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.softplus">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">softplus</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd">    Add the softplus activation based on the PyTorch definition.</span>
<span class="sd">    See https://pytorch.org/docs/stable/generated/torch.nn.functional.softplus.html#torch-nn-functional-softplus for a</span>
<span class="sd">    description of that function.</span>
<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            Input TensorRT LLM Tensor.</span>
<span class="sd">        beta : float</span>
<span class="sd">            The parameter for softplus computation.</span>
<span class="sd">        threshold : float</span>
<span class="sd">            The threshold for reverting to the linear function when input * beta &gt; threshold</span>
<span class="sd">    Returns:</span>
<span class="sd">        The output tensor created by that layer.</span>
<span class="sd">    &#39;&#39;&#39;</span>
    <span class="c1"># TRT&#39;s SOFTPLUS activation computes alpha * log(exp(beta * x) + 1);</span>
    <span class="c1"># alpha = 1 / beta makes it match torch.nn.functional.softplus(x, beta).</span>
    <span class="n">sf_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_activation</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
                                               <span class="n">trt</span><span class="o">.</span><span class="n">ActivationType</span><span class="o">.</span><span class="n">SOFTPLUS</span><span class="p">)</span>
    <span class="n">sf_layer</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">beta</span>
    <span class="n">sf_layer</span><span class="o">.</span><span class="n">beta</span> <span class="o">=</span> <span class="n">beta</span>
    <span class="c1"># Elementwise mask of positions where input * beta exceeds the threshold.</span>
    <span class="n">prod_tensor</span> <span class="o">=</span> <span class="nb">input</span> <span class="o">*</span> <span class="n">beta</span>
    <span class="n">result</span> <span class="o">=</span> <span class="n">prod_tensor</span> <span class="o">&gt;</span> <span class="n">threshold</span>
    <span class="c1"># Above the threshold revert to the linear function (the raw input),</span>
    <span class="c1"># mirroring PyTorch&#39;s numerical-stability rule for softplus.</span>
    <span class="k">return</span> <span class="n">where</span><span class="p">(</span><span class="n">result</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">sf_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
                                               <span class="n">sf_layer</span><span class="p">))</span></div>
<div class="viewcode-block" id="outer">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.outer">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">outer</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">vec2</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w">    </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd">    Add an operation to compute the outer product between two tensors.</span>
<span class="sd">    That operation creates an Einsum node.</span>
<span class="sd">    Parameters:</span>
<span class="sd">        input : Tensor</span>
<span class="sd">            The first input tensor.</span>
<span class="sd">        vec2 : Tensor</span>
<span class="sd">            The second input tensor.</span>
<span class="sd">    Returns:</span>
<span class="sd">        The output tensor produced by this layer.</span>
<span class="sd">    &#39;&#39;&#39;</span>
    <span class="c1"># &#39;i,j-&gt;ij&#39;: outer product of two 1-D operands; the result is 2-D (i, j).</span>
    <span class="k">return</span> <span class="n">einsum</span><span class="p">(</span><span class="s1">&#39;i,j-&gt;ij&#39;</span><span class="p">,</span> <span class="p">[</span><span class="nb">input</span><span class="p">,</span> <span class="n">vec2</span><span class="p">])</span></div>
<div class="viewcode-block" id="avg_pool2d">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.avg_pool2d">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">avg_pool2d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
               <span class="n">kernel_size</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
               <span class="n">stride</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
               <span class="n">padding</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
               <span class="n">ceil_mode</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
               <span class="n">count_include_pad</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
    <span class="c1"># Add a 2D average-pooling layer (trt.PoolingType.AVERAGE) over the input.</span>
    <span class="c1"># A 3-D (C, H, W) input is temporarily promoted to 4-D (1, C, H, W) and the</span>
    <span class="c1"># leading batch dimension is dropped again before returning.</span>
    <span class="c1"># NOTE(review): padding, ceil_mode and count_include_pad are accepted for</span>
    <span class="c1"># API parity with torch.nn.functional.avg_pool2d but are never forwarded</span>
    <span class="c1"># to the TRT layer -- confirm this is intentional.</span>
    <span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
    <span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
    <span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
        <span class="nb">input</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_pooling_nd</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
                                            <span class="n">trt</span><span class="o">.</span><span class="n">PoolingType</span><span class="o">.</span><span class="n">AVERAGE</span><span class="p">,</span>
                                            <span class="n">kernel_size</span><span class="p">)</span>
    <span class="c1"># PyTorch semantics: the stride defaults to the kernel size.</span>
    <span class="k">if</span> <span class="n">stride</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">stride</span> <span class="o">=</span> <span class="n">kernel_size</span>
    <span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="n">stride</span>
    <span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
    <span class="c1"># Drop the batch dimension that was added for a 3-D input.</span>
    <span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
            <span class="n">concat</span><span class="p">([</span><span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
                    <span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
                    <span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)]))</span>
    <span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="conv1d">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.conv1d">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">conv1d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
           <span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
           <span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
           <span class="n">stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
           <span class="n">padding</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
           <span class="n">dilation</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
           <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
    <span class="c1"># 1D convolution implemented with TRT&#39;s N-D convolution: a trailing dummy</span>
    <span class="c1"># spatial axis of size 1 is appended to the input and to the kernel shape,</span>
    <span class="c1"># then squeezed away from the result.</span>
    <span class="n">noutput</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">kernel_size</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">2</span><span class="p">]</span>
    <span class="c1"># Weights produced by a constant layer are folded directly into the conv</span>
    <span class="c1"># layer; otherwise they are wired in as dynamic inputs via set_input below.</span>
    <span class="n">is_weight_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
                          <span class="ow">and</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
    <span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_weight_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
    <span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">is_bias_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">bias</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
                            <span class="ow">and</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
        <span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_bias_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
    <span class="c1"># Append the dummy axis: (..., L) becomes (..., L, 1).</span>
    <span class="n">input_shuffled</span> <span class="o">=</span> <span class="n">stack</span><span class="p">([</span><span class="nb">input</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
    <span class="n">kernel_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Dims</span><span class="p">([</span><span class="n">kernel_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_convolution_nd</span><span class="p">(</span><span class="n">input_shuffled</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
                                                <span class="n">noutput</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span>
                                                <span class="n">bias</span><span class="p">)</span>
    <span class="c1"># NOTE(review): the dummy axis has size 1, so its stride/dilation value of</span>
    <span class="c1"># 2 is inert; (stride, 1) / (dilation, 1) would read more clearly -- confirm.</span>
    <span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="p">(</span><span class="n">stride</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
    <span class="n">layer</span><span class="o">.</span><span class="n">padding_nd</span> <span class="o">=</span> <span class="p">(</span><span class="n">padding</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
    <span class="n">layer</span><span class="o">.</span><span class="n">dilation_nd</span> <span class="o">=</span> <span class="p">(</span><span class="n">dilation</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
    <span class="n">layer</span><span class="o">.</span><span class="n">num_groups</span> <span class="o">=</span> <span class="n">groups</span>
    <span class="k">if</span> <span class="ow">not</span> <span class="n">is_weight_constant</span><span class="p">:</span>
        <span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">is_bias_constant</span><span class="p">:</span>
        <span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
    <span class="c1"># Squeeze the dummy axis back out of the 2D convolution result.</span>
    <span class="n">output_2d</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
    <span class="n">output_1d</span> <span class="o">=</span> <span class="n">squeeze</span><span class="p">(</span><span class="n">output_2d</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">output_1d</span></div>
<div class="viewcode-block" id="conv2d">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.conv2d">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">conv2d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
           <span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
           <span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
           <span class="n">stride</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
           <span class="n">padding</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
           <span class="n">dilation</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
           <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
           <span class="n">pre_padding</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
           <span class="n">post_padding</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
    <span class="c1"># Add a 2D convolution layer. A 3-D (C, H, W) input is temporarily</span>
    <span class="c1"># promoted to 4-D (1, C, H, W) and the leading batch dimension is dropped</span>
    <span class="c1"># again before returning. pre_padding/post_padding override the symmetric</span>
    <span class="c1"># padding on each side of the spatial dims when given.</span>
    <span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
    <span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
        <span class="nb">input</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
    <span class="n">noutput</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">kernel_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
    <span class="c1"># Weights/bias produced by a constant layer are folded into the conv</span>
    <span class="c1"># layer; otherwise they are wired in as dynamic inputs via set_input below.</span>
    <span class="n">is_weight_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
                          <span class="ow">and</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
    <span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_weight_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
    <span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">is_bias_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">bias</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
                            <span class="ow">and</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
        <span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_bias_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
    <span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_convolution_nd</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">noutput</span><span class="p">,</span>
                                                <span class="n">kernel_size</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
    <span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="n">stride</span>
    <span class="n">layer</span><span class="o">.</span><span class="n">padding_nd</span> <span class="o">=</span> <span class="n">padding</span>
    <span class="c1"># The dilation was previously assigned twice; the redundant second</span>
    <span class="c1"># assignment has been removed.</span>
    <span class="n">layer</span><span class="o">.</span><span class="n">dilation_nd</span> <span class="o">=</span> <span class="n">dilation</span>
    <span class="n">layer</span><span class="o">.</span><span class="n">num_groups</span> <span class="o">=</span> <span class="n">groups</span>
    <span class="k">if</span> <span class="n">pre_padding</span><span class="p">:</span>
        <span class="n">layer</span><span class="o">.</span><span class="n">pre_padding</span> <span class="o">=</span> <span class="n">pre_padding</span>
    <span class="k">if</span> <span class="n">post_padding</span><span class="p">:</span>
        <span class="n">layer</span><span class="o">.</span><span class="n">post_padding</span> <span class="o">=</span> <span class="n">post_padding</span>
    <span class="k">if</span> <span class="ow">not</span> <span class="n">is_weight_constant</span><span class="p">:</span>
        <span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">is_bias_constant</span><span class="p">:</span>
        <span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
    <span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
    <span class="c1"># Drop the batch dimension that was added for a 3-D input.</span>
    <span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
            <span class="n">concat</span><span class="p">([</span><span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
                    <span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
                    <span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)]))</span>
    <span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="conv3d">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.conv3d">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">conv3d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stride</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">padding</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">dilation</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="c1">##</span>
<span class="c1">## TODO: Document this function!</span>
<span class="c1">##</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="c1"># TRT requires the input of Conv3D layer to be 5-dimentional tensor.</span>
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">4</span><span class="p">:</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">5</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">stride</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">stride</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="n">stride</span><span class="p">]</span> <span class="o">*</span> <span class="mi">3</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">padding</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">padding</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="n">padding</span><span class="p">]</span> <span class="o">*</span> <span class="mi">3</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dilation</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="n">dilation</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="n">dilation</span><span class="p">]</span> <span class="o">*</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">noutput</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">kernel_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">3</span><span class="p">],</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">is_weight_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="ow">and</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
<span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_weight_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">is_bias_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">bias</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="ow">and</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
<span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_bias_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_convolution_nd</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">noutput</span><span class="p">,</span>
<span class="n">kernel_size</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="n">stride</span>
<span class="n">layer</span><span class="o">.</span><span class="n">padding_nd</span> <span class="o">=</span> <span class="n">padding</span>
<span class="n">layer</span><span class="o">.</span><span class="n">dilation_nd</span> <span class="o">=</span> <span class="n">dilation</span>
<span class="n">layer</span><span class="o">.</span><span class="n">num_groups</span> <span class="o">=</span> <span class="n">groups</span>
<span class="n">layer</span><span class="o">.</span><span class="n">dilation_nd</span> <span class="o">=</span> <span class="n">dilation</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">is_weight_constant</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">is_bias_constant</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="conv_transpose2d">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.conv_transpose2d">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">conv_transpose2d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stride</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">padding</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">output_padding</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">dilation</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="c1">##</span>
<span class="c1">## TODO: Document that function!</span>
<span class="c1">##</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="nb">input</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">()</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">noutput</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">kernel_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">2</span><span class="p">],</span> <span class="n">weight</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">is_weight_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">weight</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="ow">and</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
<span class="n">weight</span> <span class="o">=</span> <span class="n">weight</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_weight_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">is_bias_constant</span> <span class="o">=</span> <span class="p">(</span><span class="n">bias</span><span class="o">.</span><span class="n">producer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="ow">and</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">type</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">LayerType</span><span class="o">.</span><span class="n">CONSTANT</span><span class="p">)</span>
<span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span><span class="o">.</span><span class="n">producer</span><span class="o">.</span><span class="n">weights</span> <span class="k">if</span> <span class="n">is_bias_constant</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">Weights</span><span class="p">()</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_deconvolution_nd</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">noutput</span><span class="p">,</span>
<span class="n">kernel_size</span><span class="p">,</span> <span class="n">weight</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">stride_nd</span> <span class="o">=</span> <span class="n">stride</span>
<span class="n">layer</span><span class="o">.</span><span class="n">padding_nd</span> <span class="o">=</span> <span class="n">padding</span>
<span class="n">layer</span><span class="o">.</span><span class="n">num_groups</span> <span class="o">=</span> <span class="n">groups</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">is_weight_constant</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">if</span> <span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">is_bias_constant</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
<span class="k">return</span> <span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span><span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span>
<span class="n">output</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)]))</span>
<span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="split">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.split">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">split</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">split_size_or_sections</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation that splits a tensor into sub-tensors.</span>
<span class="sd"> This operation creates a list of tensors that are obtained from the input</span>
<span class="sd"> tensor by slicing it along the dimension &#39;dim&#39;. If &#39;split_size_or_sections&#39;</span>
<span class="sd"> is an integer, the tensor is split into &#39;input.shape[dim] /</span>
<span class="sd"> split_size_or_sections&#39; slices. If &#39;split_size_or_sections&#39; is a list of</span>
<span class="sd"> sizes, the tensor is split into &#39;len(split_size_or_sections)&#39; slices and</span>
<span class="sd"> the size of the ith slice is given by &#39;split_size_or_sections[i]&#39;.</span>
<span class="sd"> There are several constraints with the current implementation:</span>
<span class="sd"> - The input tensor must be static (no dynamic dimension),</span>
<span class="sd"> - If &#39;split_size_or_sections&#39; is an integer, the number of elements in</span>
<span class="sd"> the &#39;dim&#39; dimension of the input must be a multiple of</span>
<span class="sd"> &#39;split_size_or_sections&#39;: &#39;input.shape[dim] % split_size_or_sections == 0&#39;.</span>
<span class="sd"> - If &#39;split_size_or_sections&#39; is a sequence, the sum of the elements in</span>
<span class="sd"> &#39;split_size_or_sections&#39; must be equal to the size in the dimension</span>
<span class="sd"> &#39;dim&#39;: &#39;input.shape[dim] == sum(ii for ii in split_size_or_sections)&#39;.</span>
<span class="sd"> That operation is implemented using a &#39;slice&#39; operation for each output</span>
<span class="sd"> slice.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> tensor : Tensor</span>
<span class="sd"> The input tensor to slice.</span>
<span class="sd"> split_size_or_sections : Union[int, Sequence[int]]</span>
<span class="sd"> If it is an integer, it encodes the size of each slice. Otherwise,</span>
<span class="sd"> if it is a sequence, it is the size of each slice.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension of the tensor to slice.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The list of tensors produced by the different operations.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="n">tensor</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">if</span> <span class="n">dim</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">dim</span> <span class="o">+=</span> <span class="n">ndim</span>
<span class="n">dim_value</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="n">dim</span><span class="p">]</span>
<span class="n">starts</span> <span class="o">=</span> <span class="p">[</span><span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="mi">0</span><span class="p">]))</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
<span class="n">sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">split_size_or_sections</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
<span class="c1"># TODO: support non-divisible cases</span>
<span class="k">assert</span> <span class="n">dim_value</span> <span class="o">%</span> <span class="n">split_size_or_sections</span> <span class="o">==</span> <span class="mi">0</span>
<span class="n">num_sections</span> <span class="o">=</span> <span class="n">dim_value</span> <span class="o">//</span> <span class="n">split_size_or_sections</span>
<span class="n">sizes</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="n">split_size_or_sections</span><span class="p">]))</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_sections</span><span class="p">):</span>
<span class="n">starts</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="n">split_size_or_sections</span> <span class="o">*</span> <span class="n">i</span><span class="p">]))</span>
<span class="n">outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">concat</span><span class="p">(</span><span class="n">starts</span><span class="p">),</span> <span class="n">concat</span><span class="p">(</span><span class="n">sizes</span><span class="p">)))</span>
<span class="k">return</span> <span class="n">outputs</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">total_size</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">split_size_or_sections</span><span class="p">:</span>
<span class="n">total_size</span> <span class="o">+=</span> <span class="n">i</span>
<span class="k">assert</span> <span class="n">dim_value</span> <span class="o">==</span> <span class="n">total_size</span>
<span class="n">num_sections</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">split_size_or_sections</span><span class="p">)</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_sections</span><span class="p">):</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">starts</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">starts</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">+</span> <span class="n">sizes</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span>
<span class="n">sizes</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="n">split_size_or_sections</span><span class="p">[</span><span class="n">i</span><span class="p">]]))</span>
<span class="n">outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">concat</span><span class="p">(</span><span class="n">starts</span><span class="p">),</span> <span class="n">concat</span><span class="p">(</span><span class="n">sizes</span><span class="p">)))</span>
<span class="k">return</span> <span class="n">outputs</span></div>
<div class="viewcode-block" id="chunk">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.chunk">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">chunk</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">chunks</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation that splits a tensor into sub-tensors.</span>
<span class="sd"> This operation creates a list of tensors that are obtained from the input</span>
<span class="sd"> tensor by chunking it along the dimension &#39;dim&#39;. It produces &#39;chunks&#39;</span>
<span class="sd"> sub-tensors.</span>
<span class="sd"> That operation is only defined for static tensors (no dynamic dimension)</span>
<span class="sd"> and the size of the tensor in the dimension &#39;dim&#39; must be a multiple of</span>
<span class="sd"> &#39;chunks&#39;: &#39;input.shape[dim] % chunks == 0&#39;.</span>
<span class="sd"> It maps to &#39;split&#39; with &#39;split_size = input.shape[dim] / chunks&#39;.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> tensor : Tensor</span>
<span class="sd"> The input tensor to slice.</span>
<span class="sd"> chunks : int</span>
<span class="sd"> The number of slices to split the input tensor into.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension of the tensor to slice.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The list of tensors produced by the different operations.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="ow">not</span> <span class="n">tensor</span><span class="o">.</span><span class="n">is_dynamic</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">if</span> <span class="n">dim</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">dim</span> <span class="o">+=</span> <span class="n">ndim</span>
<span class="n">dim_value</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="n">dim</span><span class="p">]</span>
<span class="k">assert</span> <span class="n">dim_value</span> <span class="o">%</span> <span class="n">chunks</span> <span class="o">==</span> <span class="mi">0</span>
<span class="k">return</span> <span class="n">split</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">dim_value</span> <span class="o">//</span> <span class="n">chunks</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span></div>
<div class="viewcode-block" id="unbind">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.unbind">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">unbind</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Removes a tensor dimension.</span>
<span class="sd"> Returns a tuple of all slices along a given dimension, already without it.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
<span class="n">output_shape</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="n">dim</span><span class="p">]</span>
<span class="k">return</span> <span class="p">[</span><span class="n">output</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">output_shape</span><span class="p">)</span> <span class="k">for</span> <span class="n">output</span> <span class="ow">in</span> <span class="n">outputs</span><span class="p">]</span></div>
<div class="viewcode-block" id="AllReduceStrategy">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceStrategy">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">AllReduceStrategy</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="n">NCCL</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">MIN_LATENCY</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">UB</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">AUTO</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">ONESHOT</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">TWOSHOT</span> <span class="o">=</span> <span class="mi">5</span>
<span class="n">LOWPRECISION</span> <span class="o">=</span> <span class="mi">6</span>
<span class="n">MNNVL</span> <span class="o">=</span> <span class="mi">7</span>
<span class="n">NCCL_SYMMETRIC</span> <span class="o">=</span> <span class="mi">8</span>
<span class="n">SYMM_MEM</span> <span class="o">=</span> <span class="mi">9</span> <span class="c1"># PyTorch symmetric memory with MULTIMEM</span></div>
<div class="viewcode-block" id="AllReduceFusionOp">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceFusionOp">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">AllReduceFusionOp</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="n">NONE</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">RESIDUAL_RMS_NORM</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">LAST_PROCESS_FOR_UB</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">RESIDUAL_RMS_PREPOST_NORM</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">RESIDUAL_RMS_NORM_QUANT_FP8</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">RESIDUAL_RMS_NORM_QUANT_NVFP4</span> <span class="o">=</span> <span class="mi">5</span>
<span class="n">RESIDUAL_RMS_NORM_OUT_QUANT_FP8</span> <span class="o">=</span> <span class="mi">6</span>
<span class="n">RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4</span> <span class="o">=</span> <span class="mi">7</span>
<span class="n">MOE_FINALIZE_ALLREDUCE_RESIDUAL_RMS_NORM</span> <span class="o">=</span> <span class="mi">8</span></div>
<div class="viewcode-block" id="AllReduceParams">
<!-- Auto-generated Sphinx "viewcode" page: Pygments-highlighted HTML rendering of
     tensorrt_llm/functional.py. Fix the Python source and rebuild the docs rather
     than editing this file by hand. -->
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceParams">[docs]</a>
<!-- AllReduceParams: parameter holder for the all-reduce plugin (strategy, fusion op,
     optional bias/residual/norm-weight/scale tensors, eps, and two torch-path flags). -->
<span class="k">class</span><span class="w"> </span><span class="nc">AllReduceParams</span><span class="p">():</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">strategy</span><span class="p">:</span> <span class="n">AllReduceStrategy</span> <span class="o">=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">AUTO</span><span class="p">,</span>
<span class="n">fusion_op</span><span class="p">:</span> <span class="n">AllReduceFusionOp</span> <span class="o">=</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">NONE</span><span class="p">,</span>
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">residual</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">norm_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">norm_pre_residual_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-06</span><span class="p">,</span>
<span class="n">enable_allreduce</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">trigger_completion_at_end</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">strategy</span> <span class="o">=</span> <span class="n">strategy</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">=</span> <span class="n">fusion_op</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">bias</span>
<span class="bp">self</span><span class="o">.</span><span class="n">residual</span> <span class="o">=</span> <span class="n">residual</span>
<span class="bp">self</span><span class="o">.</span><span class="n">norm_weight</span> <span class="o">=</span> <span class="n">norm_weight</span>
<span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">scale</span>
<span class="bp">self</span><span class="o">.</span><span class="n">norm_pre_residual_weight</span> <span class="o">=</span> <span class="n">norm_pre_residual_weight</span>
<span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
<span class="c1"># For torch path only, has no effect on TRT path</span>
<span class="bp">self</span><span class="o">.</span><span class="n">enable_allreduce</span> <span class="o">=</span> <span class="n">enable_allreduce</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trigger_completion_at_end</span> <span class="o">=</span> <span class="n">trigger_completion_at_end</span>
<!-- NOTE(review): the rendered assert compares the AllReduceFusionOp enum member to
     NONE.value — presumably AllReduceFusionOp is an IntEnum so this compares equal;
     confirm in the Python source, and fix there (not here) if it is a defect. -->
<span class="k">assert</span> <span class="n">fusion_op</span> <span class="o">==</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">NONE</span><span class="o">.</span><span class="n">value</span> <span class="ow">or</span> <span class="p">(</span><span class="n">residual</span>
<span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span>
<!-- has_affine / has_bias / has_scale return 1 or 0 depending on whether the
     corresponding optional tensor was supplied; used as plugin-field flags below. -->
<div class="viewcode-block" id="AllReduceParams.has_affine">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceParams.has_affine">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">has_affine</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="mi">1</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm_weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="mi">0</span></div>
<div class="viewcode-block" id="AllReduceParams.has_bias">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceParams.has_bias">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">has_bias</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="mi">1</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="mi">0</span></div>
<div class="viewcode-block" id="AllReduceParams.has_scale">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceParams.has_scale">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">has_scale</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="mi">1</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="mi">0</span></div>
<!-- update_strategy: switches AUTO to UB when the network's plugin config enables
     user buffers. -->
<div class="viewcode-block" id="AllReduceParams.update_strategy">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.AllReduceParams.update_strategy">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">update_strategy</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">strategy</span> <span class="o">==</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">AUTO</span> <span class="ow">and</span> <span class="n">default_net</span><span class="p">(</span>
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">user_buffer</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">strategy</span> <span class="o">=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span></div>
</div>
<!-- MoEAllReduceParams: AllReduceParams subclass carrying extra MoE tensors
     (device_num_experts, expert_scale_factor, expanded_idx_to_permuted_idx,
     shared_expert_output) plus the is_cutlass_min_latency flag. -->
<div class="viewcode-block" id="MoEAllReduceParams">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.MoEAllReduceParams">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">MoEAllReduceParams</span><span class="p">(</span><span class="n">AllReduceParams</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">device_num_experts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">expert_scale_factor</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">expanded_idx_to_permuted_idx</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">shared_expert_output</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">residual</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">norm_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">norm_pre_residual_weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-06</span><span class="p">,</span>
<span class="n">enable_allreduce</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">is_cutlass_min_latency</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
<!-- NOTE(review): only bias/residual/norm_weight/scale/norm_pre_residual_weight/
     eps/enable_allreduce are forwarded to the base __init__, so strategy and
     fusion_op always take their base defaults here — presumably intentional;
     confirm against the Python source. -->
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
<span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span>
<span class="n">residual</span><span class="o">=</span><span class="n">residual</span><span class="p">,</span>
<span class="n">norm_weight</span><span class="o">=</span><span class="n">norm_weight</span><span class="p">,</span>
<span class="n">scale</span><span class="o">=</span><span class="n">scale</span><span class="p">,</span>
<span class="n">norm_pre_residual_weight</span><span class="o">=</span><span class="n">norm_pre_residual_weight</span><span class="p">,</span>
<span class="n">eps</span><span class="o">=</span><span class="n">eps</span><span class="p">,</span>
<span class="n">enable_allreduce</span><span class="o">=</span><span class="n">enable_allreduce</span><span class="p">,</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">device_num_experts</span> <span class="o">=</span> <span class="n">device_num_experts</span>
<span class="bp">self</span><span class="o">.</span><span class="n">expert_scale_factor</span> <span class="o">=</span> <span class="n">expert_scale_factor</span>
<span class="bp">self</span><span class="o">.</span><span class="n">expanded_idx_to_permuted_idx</span> <span class="o">=</span> <span class="n">expanded_idx_to_permuted_idx</span>
<span class="bp">self</span><span class="o">.</span><span class="n">shared_expert_output</span> <span class="o">=</span> <span class="n">shared_expert_output</span>
<span class="bp">self</span><span class="o">.</span><span class="n">is_cutlass_min_latency</span> <span class="o">=</span> <span class="n">is_cutlass_min_latency</span>
<!-- is_valid: in min-latency mode requires the three cutlass tensors; otherwise
     requires only expanded_idx_to_permuted_idx. -->
<div class="viewcode-block" id="MoEAllReduceParams.is_valid">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.MoEAllReduceParams.is_valid">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">is_valid</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_cutlass_min_latency</span><span class="p">:</span>
<span class="k">return</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device_num_experts</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">expert_scale_factor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">shared_expert_output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">expanded_idx_to_permuted_idx</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span></div>
</div>
<!-- create_allreduce_plugin: looks up the 'AllReduce' v1 plugin creator, packs the
     plugin fields (group, type_id, strategy, fusion_op, eps, affine/bias/scale
     flags), assembles the plugin inputs in the order the plugin expects, and adds
     the plugin layer to the given network. Returns (layer, creator, field
     collection). -->
<div class="viewcode-block" id="create_allreduce_plugin">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.create_allreduce_plugin">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">create_allreduce_plugin</span><span class="p">(</span>
<span class="n">network</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">INetworkDefinition</span><span class="p">,</span>
<span class="n">tensor</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">ITensor</span><span class="p">,</span>
<span class="n">workspace</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">ITensor</span><span class="p">],</span>
<span class="n">group</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">,</span>
<span class="n">dtype</span><span class="p">:</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">,</span>
<span class="n">all_reduce_params</span><span class="p">:</span> <span class="n">AllReduceParams</span><span class="p">,</span>
<span class="p">):</span>
<span class="n">allreduce_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;AllReduce&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">allreduce_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">pf_group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;group&quot;</span><span class="p">,</span> <span class="n">group</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pf_dtype</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">dtype</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="p">[</span><span class="n">pf_group</span><span class="p">,</span> <span class="n">pf_dtype</span><span class="p">]</span>
<!-- strategy and fusion_op are serialized as single INT8 plugin fields. -->
<span class="n">p_strategy</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;strategy&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_strategy</span><span class="p">)</span>
<span class="n">p_fusion_op</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;fusion_op&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_fusion_op</span><span class="p">)</span>
<span class="n">p_eps</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;eps&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">float</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">eps</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_eps</span><span class="p">)</span>
<!-- affine/bias/scale flags come from the has_*() helpers (0 or 1). -->
<span class="n">p_affine</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;affine&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_affine</span><span class="p">())],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_affine</span><span class="p">)</span>
<span class="n">p_bias</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;bias&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_bias</span><span class="p">())],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_bias</span><span class="p">)</span>
<span class="n">p_scale</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;scale&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_scale</span><span class="p">())],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pfc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">p_scale</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">(</span><span class="n">pfc</span><span class="p">)</span>
<span class="n">ar_plug</span> <span class="o">=</span> <span class="n">allreduce_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;allreduce&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<!-- Input ordering: tensor, then workspace (only for strategies outside
     {NCCL, UB, NCCL_SYMMETRIC}), then — when a fusion op is set — optional bias,
     required residual, optional norm weight(s), optional scale. -->
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="p">]</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">{</span>
<span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">NCCL</span><span class="p">,</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span><span class="p">,</span>
<span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">NCCL_SYMMETRIC</span>
<span class="p">}:</span>
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">workspace</span><span class="p">)</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">!=</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">NONE</span><span class="p">:</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_bias</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">residual</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_affine</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">norm_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">==</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">RESIDUAL_RMS_PREPOST_NORM</span><span class="p">:</span>
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">all_reduce_params</span><span class="o">.</span><span class="n">norm_pre_residual_weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_scale</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">plug_inputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">all_reduce_params</span><span class="o">.</span><span class="n">scale</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">network</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">ar_plug</span><span class="p">)</span>
<span class="k">return</span> <span class="n">layer</span><span class="p">,</span> <span class="n">allreduce_plg_creator</span><span class="p">,</span> <span class="n">pfc</span></div>
<!-- Module-level counter, incremented by allreduce() below and used to build
     unique "allreduce_ub_N" output names for the UB strategy. -->
<span class="n">allreduce_ub_counter</span> <span class="o">=</span> <span class="mi">0</span>
<div class="viewcode-block" id="allreduce">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.allreduce">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">allreduce</span><span class="p">(</span>
<span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
<span class="n">all_reduce_params</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AllReduceParams</span><span class="p">]</span> <span class="o">=</span> <span class="n">AllReduceParams</span><span class="p">()</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation that performs a collective all-reduce.</span>
<span class="sd"> Let&#39;s define &#39;world_size&#39; as the length of the &#39;group&#39; list. That functions</span>
<span class="sd"> creates a layer to compute the sum of &#39;world_size&#39; tensors distributed</span>
<span class="sd"> amongst the &#39;world_size&#39; participating ranks (one GPU per rank).</span>
<span class="sd"> The list &#39;group&#39; contains the identifiers of the ranks participating into</span>
<span class="sd"> the collective operation.</span>
<span class="sd"> The tensors in the different ranks must be 1D tensors (or views) and the output</span>
<span class="sd"> tensor will have that same shape. The output tensor will be replicated on</span>
<span class="sd"> the &#39;world_size&#39; ranks.</span>
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL all-reduce</span>
<span class="sd"> collective operation. See</span>
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce</span>
<span class="sd"> for details.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> tensor : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> group : List[int]</span>
<span class="sd"> The ranks participating into the all-reduce operation.</span>
<span class="sd"> strategy: AllReduceStrategy</span>
<span class="sd"> NCCL delegates all-reduce to NCCL while ONESHOT and TWOSHOT are custom latency-optimal algorithms.</span>
<span class="sd"> AUTO chooses amongst the three based on a message-size heuristic.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by that layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">global</span> <span class="n">allreduce_ub_counter</span>
<span class="n">allreduce_ub_counter</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">if</span> <span class="n">all_reduce_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">all_reduce_params</span> <span class="o">=</span> <span class="n">AllReduceParams</span><span class="p">()</span>
<span class="n">all_reduce_params</span><span class="o">.</span><span class="n">update_strategy</span><span class="p">()</span>
<span class="c1"># TODO(TRTLLM-996): remove this WAR when custom allreduce is supported</span>
<span class="c1"># for encoder models in C++ runtime.</span>
<span class="n">workspace</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">!=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">NCCL</span> <span class="ow">and</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">!=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span><span class="p">:</span>
<span class="k">if</span> <span class="n">current_all_reduce_helper</span><span class="p">()</span><span class="o">.</span><span class="n">workspace</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">=</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">NCCL_SYMMETRIC</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">workspace</span> <span class="o">=</span> <span class="n">current_all_reduce_helper</span><span class="p">()</span><span class="o">.</span><span class="n">workspace</span><span class="o">.</span><span class="n">trt_tensor</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">==</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span><span class="p">:</span>
<span class="n">tensor</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="s2">&quot;allreduce_ub_0_&quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">allreduce_ub_counter</span><span class="p">))</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
<span class="n">layer</span><span class="p">,</span> <span class="n">allreduce_plg_creator</span><span class="p">,</span> <span class="n">pfc</span> <span class="o">=</span> <span class="n">create_allreduce_plugin</span><span class="p">(</span>
<span class="n">network</span><span class="o">=</span><span class="n">default_trtnet</span><span class="p">(),</span>
<span class="n">tensor</span><span class="o">=</span><span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">workspace</span><span class="o">=</span><span class="n">workspace</span><span class="p">,</span>
<span class="n">group</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span>
<span class="n">all_reduce_params</span><span class="o">=</span><span class="n">all_reduce_params</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">allreduce_plg_creator</span><span class="p">,</span> <span class="s2">&quot;allreduce&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">!=</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">NONE</span><span class="p">:</span>
<span class="n">inter_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
<span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">==</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span> <span class="ow">and</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_scale</span><span class="p">(</span>
<span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">final_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">==</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">RESIDUAL_RMS_NORM_QUANT_NVFP4</span><span class="p">:</span>
<span class="n">scale_factor</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">final_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">strategy</span> <span class="o">==</span> <span class="n">AllReduceStrategy</span><span class="o">.</span><span class="n">UB</span><span class="p">:</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">has_scale</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">final_output</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="s2">&quot;allreduce_ub_1_&quot;</span> <span class="o">+</span>
<span class="nb">str</span><span class="p">(</span><span class="n">allreduce_ub_counter</span><span class="p">))</span>
<span class="k">if</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">==</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">RESIDUAL_RMS_NORM_QUANT_NVFP4</span><span class="p">:</span>
<span class="n">scale_factor</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="s2">&quot;allreduce_ub_2_&quot;</span> <span class="o">+</span>
<span class="nb">str</span><span class="p">(</span><span class="n">allreduce_ub_counter</span><span class="p">))</span>
<span class="k">return</span> <span class="p">(</span><span class="n">final_output</span><span class="p">,</span> <span class="n">scale_factor</span><span class="p">),</span> <span class="n">inter_output</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">all_reduce_params</span><span class="o">.</span><span class="n">fusion_op</span> <span class="o">==</span> <span class="n">AllReduceFusionOp</span><span class="o">.</span><span class="n">LAST_PROCESS_FOR_UB</span>
<span class="n">inter_output</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="s2">&quot;allreduce_ub_1_&quot;</span> <span class="o">+</span>
<span class="nb">str</span><span class="p">(</span><span class="n">allreduce_ub_counter</span><span class="p">))</span>
<span class="k">return</span> <span class="n">final_output</span><span class="p">,</span> <span class="n">inter_output</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">final_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">return</span> <span class="n">final_output</span></div>
<div class="viewcode-block" id="allgather">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.allgather">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">allgather</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">gather_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation that performs a collective all-gather.</span>
<span class="sd"> Let&#39;s define &#39;group_size&#39; as the length of the &#39;group&#39; list. That function</span>
<span class="sd"> creates a layer to gather &#39;group_size&#39; tensors distributed</span>
<span class="sd"> amongst the &#39;group_size&#39; participating ranks (one GPU per rank).</span>
<span class="sd"> The list &#39;group&#39; contains the identifiers of the ranks participating into</span>
<span class="sd"> the collective operation.</span>
<span class="sd"> Note that &#39;group&#39; here can be either TP group or PP group, because allgather communication is not limited to a specific split pattern. Therefore &#39;group_size&#39; does not need to equal MPI &#39;world_size&#39;.</span>
<span class="sd"> The tensors in the different ranks must be 1D tensors (or views) and the</span>
<span class="sd"> output tensor will have that same shape.</span>
<span class="sd"> Given the &#39;section_size = input.shape[0] / group_size&#39;, each rank</span>
<span class="sd"> contributes a section of its input tensor that corresponds to</span>
<span class="sd"> &#39;rank*section_size:(rank+1)*section_size&#39;.</span>
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL all-gather</span>
<span class="sd"> collective operation. See</span>
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather</span>
<span class="sd"> for details.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> tensor : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> group : List[int]</span>
<span class="sd"> The ranks participating into the all-gather operation.</span>
<span class="sd"> gather_dim: int = 0</span>
<span class="sd"> Gather along given dimension. By default 0, i.e. treated as 1D tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by that layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">allgather_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;AllGather&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">allgather_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">group_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">group</span><span class="p">)</span>
<span class="n">group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;group&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">group</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
<span class="n">allgather</span> <span class="o">=</span> <span class="n">allgather_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;allgather&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">allgather</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">allgather_plg_creator</span><span class="p">,</span> <span class="s2">&quot;allgather&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="c1"># gather along a given dimension other than dim0</span>
<span class="k">if</span> <span class="n">gather_dim</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="c1"># also support -1 type of dim representation</span>
<span class="k">if</span> <span class="n">gather_dim</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">gather_dim</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">+</span> <span class="n">gather_dim</span>
<span class="c1"># plugin above gathers as 1D flattened tensor</span>
<span class="c1"># 1. [dim0, ...dimi, ...dimN] -&gt; [group_size * dim0, ...dimi, ...dimN]</span>
<span class="c1"># now we need to gather-by-dim via split-concat</span>
<span class="c1"># 2. [group_size * dim0, ...dimi, ...dimN] -&gt; [dim0, ...group_size * dimi, ...dimN]</span>
<span class="c1"># 2.1 split</span>
<span class="n">split_size</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">group_size</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="n">starts</span> <span class="o">=</span> <span class="p">[</span><span class="n">constant</span><span class="p">(</span><span class="n">dims_array</span><span class="p">([</span><span class="mi">0</span><span class="p">]))</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
<span class="n">sizes</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">d</span><span class="p">)</span> <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span><span class="p">)]</span>
<span class="n">sizes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">split_size</span>
<span class="n">sections</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">group_size</span><span class="p">):</span>
<span class="n">starts</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">split_size</span> <span class="o">*</span> <span class="n">i</span>
<span class="n">sections</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">concat</span><span class="p">(</span><span class="n">starts</span><span class="p">),</span> <span class="n">concat</span><span class="p">(</span><span class="n">sizes</span><span class="p">)))</span>
<span class="c1"># 2.2 concat</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">sections</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="n">gather_dim</span><span class="p">)</span>
<span class="k">return</span> <span class="n">x</span></div>
<div class="viewcode-block" id="reduce_scatter">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.reduce_scatter">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">reduce_scatter</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="n">plg_creater</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;ReduceScatter&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">plg_creater</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;group&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">group</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
<span class="n">reduce_scatter_plug</span> <span class="o">=</span> <span class="n">plg_creater</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;reduce_scatter&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">reduce_scatter_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creater</span><span class="p">,</span> <span class="s2">&quot;reduce_scatter&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="send">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.send">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">send</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation that performs a send from a rank to another.</span>
<span class="sd"> The send operation sends a tensor from one rank to another. If a rank &#39;i&#39;</span>
<span class="sd"> sends a tensor to a rank &#39;j&#39;, the rank &#39;j&#39; must have a corresponding &#39;recv&#39;</span>
<span class="sd"> operation from rank &#39;i&#39;. See &#39;recv&#39;.</span>
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL send</span>
<span class="sd"> point-to-point operation. See</span>
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend</span>
<span class="sd"> for details.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> tensor : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> tgt : int</span>
<span class="sd"> The rank that receives the tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by that layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">send_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;Send&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">send_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">tgt</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;tgt_rank&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tgt</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">tgt</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
<span class="n">send_plug</span> <span class="o">=</span> <span class="n">send_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;send&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">send_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">send_plg_creator</span><span class="p">,</span> <span class="s2">&quot;send&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="recv">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.recv">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">recv</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">src</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation that performs a recv to a rank from another.</span>
<span class="sd"> The recv operation receives a tensor on a rank from another rank. If a rank &#39;i&#39;</span>
<span class="sd"> receives a tensor from a rank &#39;j&#39;, the rank &#39;j&#39; must have a corresponding &#39;send&#39;</span>
<span class="sd"> operation to rank &#39;i&#39;. See &#39;send&#39;.</span>
<span class="sd"> That operation is implemented using a plugin that wraps the NCCL recv</span>
<span class="sd"> point-to-point operation. See</span>
<span class="sd"> https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclrecv</span>
<span class="sd"> for details.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> tensor : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> src : int</span>
<span class="sd"> The rank that sends the tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by that layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">recv_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;Recv&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">recv_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">src</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;src_rank&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">src</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
<span class="n">recv_plug</span> <span class="o">=</span> <span class="n">recv_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;recv&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">recv_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">recv_plg_creator</span><span class="p">,</span> <span class="s2">&quot;recv&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="gemm_allreduce">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gemm_allreduce">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">gemm_allreduce</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">b</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
<span class="n">transa</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">transb</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">alpha</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">output_dtype</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">fp8_inputs_override</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">a_sf</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">b_sf</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation that performs fused GEMM+AllReduce.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> a: Tensor</span>
<span class="sd"> Input tensor A</span>
<span class="sd"> b: Tensor</span>
<span class="sd"> Input tensor B</span>
<span class="sd"> a_sf: Optional[Tensor]</span>
<span class="sd"> Input tensor for scaling input A</span>
<span class="sd"> b_sf: Optional[Tensor]</span>
<span class="sd"> Input tensor for scaling input B</span>
<span class="sd"> group: List[int]</span>
<span class="sd"> Ranks participating in collective</span>
<span class="sd"> transa: bool</span>
<span class="sd"> Whether or not input tensor A is transposed</span>
<span class="sd"> transb: bool</span>
<span class="sd"> Whether or not input tensor B is transposed</span>
<span class="sd"> alpha: float</span>
<span class="sd"> Alpha for GEMM -&gt; beta * C + (alpha * acc)</span>
<span class="sd"> output_dtype: trt.DataType</span>
<span class="sd"> Output type for plugin. If it is None, we</span>
<span class="sd"> will use type set in plugin_config.</span>
<span class="sd"> fp8_inputs_override: bool</span>
<span class="sd"> TRT graph does not detect FP8 inputs correctly. This</span>
<span class="sd"> flag is used to override the derived input tensor</span>
<span class="sd"> types so that our plugin knows to issue FP8 MMAs.</span>
<span class="sd"> Returns:</span>
<span class="sd"> Returns GEMM output tensor which has been reduced across ranks.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="c1"># Output tensor needs to be bound to externally managed</span>
<span class="c1"># memory so keep track of layer index so we can assign</span>
<span class="c1"># output tensor unique label.</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">gemm_allreduce</span><span class="p">,</span> <span class="s1">&#39;layer_idx&#39;</span><span class="p">):</span>
<span class="n">gemm_allreduce</span><span class="o">.</span><span class="n">layer_idx</span> <span class="o">=</span> <span class="mi">0</span>
<span class="c1"># Check inputs</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">)</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">b</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">)</span>
<span class="k">if</span> <span class="n">fp8_inputs_override</span><span class="p">:</span>
<span class="k">assert</span> <span class="p">(</span>
<span class="nb">isinstance</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="ow">and</span> <span class="n">alpha</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span>
<span class="ow">and</span> <span class="n">alpha</span><span class="o">.</span><span class="n">size</span> <span class="o">==</span> <span class="mi">1</span>
<span class="p">),</span> <span class="s2">&quot;`alpha` must be passed as a float32 ndarray if `fp8_inputs_override` is enabled for gemm_allreduce_plugin&quot;</span>
<span class="k">assert</span> <span class="n">a</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span>
<span class="k">assert</span> <span class="n">b</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span>
<span class="k">if</span> <span class="n">output_dtype</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">output_dtype</span> <span class="o">=</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span>
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_allreduce_plugin</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">output_dtype</span> <span class="ow">in</span> <span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">]</span>
<span class="n">alpha_is_tensor</span> <span class="o">=</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span>
<span class="k">if</span> <span class="n">alpha</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">alpha_is_tensor</span><span class="p">:</span>
<span class="n">alpha_value</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">alpha_value</span> <span class="o">=</span> <span class="n">alpha</span>
<span class="n">plugin_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;GemmAllReduce&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">plugin_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">trt_type_a</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span> <span class="k">if</span> <span class="n">fp8_inputs_override</span> <span class="k">else</span> <span class="n">a</span><span class="o">.</span><span class="n">dtype</span>
<span class="n">trt_type_b</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span> <span class="k">if</span> <span class="n">fp8_inputs_override</span> <span class="k">else</span> <span class="n">b</span><span class="o">.</span><span class="n">dtype</span>
<span class="c1"># create plugin fields</span>
<span class="n">field_list</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">&#39;type_a&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">trt_type_a</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">))</span>
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">&#39;type_b&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">trt_type_b</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">))</span>
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">&#39;type_d&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">output_dtype</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">))</span>
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">&#39;transa&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">transa</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">))</span>
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">&#39;transb&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">transb</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">))</span>
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">&#39;group&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">))</span>
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">&#39;has_sfa&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">a_sf</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">))</span>
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">&#39;has_sfb&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">b_sf</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">))</span>
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">&#39;alpha_is_ptr&#39;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">alpha_is_tensor</span><span class="p">)],</span>
<span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">))</span>
<span class="n">field_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s1">&#39;alpha&#39;</span><span class="p">,</span> <span class="n">alpha_value</span><span class="o">.</span><span class="n">flatten</span><span class="p">(),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">))</span>
<span class="c1"># create plugin</span>
<span class="n">fields</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">(</span><span class="n">field_list</span><span class="p">)</span>
<span class="n">plugin</span> <span class="o">=</span> <span class="n">plugin_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;gemm_allreduce&quot;</span><span class="p">,</span> <span class="n">fields</span><span class="p">)</span>
<span class="c1"># define symbolic input tensors.</span>
<span class="c1"># note this does NOT allocate memory.</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">b</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="k">if</span> <span class="n">a_sf</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">a_sf</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="k">if</span> <span class="n">b_sf</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">b_sf</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="k">if</span> <span class="n">alpha_is_tensor</span><span class="p">:</span>
<span class="n">inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">alpha</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">plugin</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plugin_creator</span><span class="p">,</span> <span class="s2">&quot;gemm_allreduce&quot;</span><span class="p">,</span> <span class="n">fields</span><span class="p">)</span>
<span class="c1"># define symbolic output tensors</span>
<span class="c1"># both output tensors point to same physical memory but</span>
<span class="c1"># one has unicast address and other has multicast address</span>
<span class="n">uc_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="n">mc_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="n">ipc_output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">uc_output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="n">mc_output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="n">ipc_output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="c1"># mark outputs so that we can bind our own allocated memory in runtime</span>
<span class="c1"># (see generation.py)</span>
<span class="n">uc_output</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;gemm_allreduce_uc_out_</span><span class="si">{</span><span class="n">gemm_allreduce</span><span class="o">.</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">mc_output</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;gemm_allreduce_mc_out_</span><span class="si">{</span><span class="n">gemm_allreduce</span><span class="o">.</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">ipc_output</span><span class="o">.</span><span class="n">mark_output</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;gemm_allreduce_ipc_out_</span><span class="si">{</span><span class="n">gemm_allreduce</span><span class="o">.</span><span class="n">layer_idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">gemm_allreduce</span><span class="o">.</span><span class="n">layer_idx</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">return</span> <span class="n">uc_output</span></div>
<div class="viewcode-block" id="bert_attention">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.bert_attention">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">bert_attention</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">input_lengths</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">head_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">q_scaling</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
<span class="n">relative_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">relative_attention_bias</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">max_distance</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">max_input_length</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">sage_attn</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">sage_attn_q_block_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">sage_attn_k_block_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">sage_attn_v_block_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">cp_group</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">cp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">cp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation that performs the multi-head attention in BERT.</span>
<span class="sd"> The multi-head attention (MHA) is the sequence of a batched matmul, a</span>
<span class="sd"> softmax and a batched matmul as described in</span>
<span class="sd"> https://arxiv.org/abs/1706.03762. That function adds an operation that</span>
<span class="sd"> performs those computations using a single GPU kernel.</span>
<span class="sd"> The input tensor contains the Q, K and V elements. It is a 2D tensor and</span>
<span class="sd"> its shape is &#39;[sum_of_tokens, 3*hidden_dim]&#39; where the &#39;sum_of_tokens&#39; is</span>
<span class="sd"> the sum of the sequence lengths in the batch.</span>
<span class="sd"> In MHA, the output of the Q*K^T product is scaled by a constant value that</span>
<span class="sd"> is computed as:</span>
<span class="sd"> 1.f / (q_scaling * sqrt(head_size)).</span>
<span class="sd"> That &#39;q_scaling&#39; constant is the last argument of that function.</span>
<span class="sd"> That layer is implemented using a plugin (see bertAttentionPlugin).</span>
<span class="sd"> Parameters:</span>
<span class="sd"> tensor : Tensor</span>
<span class="sd"> The QKV input tensor.</span>
<span class="sd"> input_lengths : Tensor</span>
<span class="sd"> The length of each sequence. It is a 1D tensor of size &#39;batch_size&#39;.</span>
<span class="sd"> num_heads : int</span>
<span class="sd"> The number of heads.</span>
<span class="sd"> head_size : int</span>
<span class="sd"> The size of each head.</span>
<span class="sd"> q_scaling : float</span>
<span class="sd"> The factor to compute the scaling factor to scale the output of the</span>
<span class="sd"> &#39;Q*K^T&#39; product.</span>
<span class="sd"> relative_attention: bool = False</span>
<span class="sd"> If enable relative attention.</span>
<span class="sd"> relative_attention_bias: Tensor = None</span>
<span class="sd"> The relative attention bias [num_heads, max_seq_len, max_seq_len], or The relative attention embedding table for implicit mode, [num_heads, num_buckets].</span>
<span class="sd"> max_distance: int = 0</span>
<span class="sd"> The maximum distance of relative position in attention, for implicit mode.</span>
<span class="sd"> Default value is 0, meaning to use the regular mode of relative attention bias.</span>
<span class="sd"> Implicit mode is only enabled when passing in non-zero positive max_distance value.</span>
<span class="sd"> See relative attention bias in docs/source/advanced/gpt-attention.md</span>
<span class="sd"> max_input_length: Tensor = None</span>
<span class="sd"> The maximum input sequence length represented by Tensor shape. Requires for remove_input_padding to pre-define plugin workspace size.</span>
<span class="sd"> sage_attn: bool = False</span>
<span class="sd"> SageAttention is a 8-bit implementation of attention kernel. It&#39;s input q, k, v and output datatypes are 16-bit. It performance dynamic quantization for q, k, v</span>
<span class="sd"> tensor every time before attention. https://github.com/thu-ml/SageAttention</span>
<span class="sd"> sage_attn_q_block_size: int = 0</span>
<span class="sd"> dynamic quant block size along sequence dimension of q tensor. Each quant block will share one scale.</span>
<span class="sd"> sage_attn_k_block_size: int = 0</span>
<span class="sd"> dynamic quant block size along sequence dimension of k tensor. Each quant block will share one scale.</span>
<span class="sd"> sage_attn_v_block_size: int = 0</span>
<span class="sd"> dynamic quant block size along sequence dimension of v tensor. Each quant block will share one scale.</span>
<span class="sd"> cp_group: list[int] = None</span>
<span class="sd"> The communication group for context parallel</span>
<span class="sd"> cp_size: int = 1</span>
<span class="sd"> The communication size for context parallel</span>
<span class="sd"> cp_rank: int = 0</span>
<span class="sd"> The communication rank for context parallel</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by that layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">attn_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;BertAttention&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">attn_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">nheads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;num_heads&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">head_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;head_size&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">head_size</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">q_scaling</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;q_scaling&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">q_scaling</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">context_fmha_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;context_fmha_type&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">context_fmha_type</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">bert_attention_plugin</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">do_relative_attention</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;do_relative_attention&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">relative_attention</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">max_distance</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;max_distance&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_distance</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">remove_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;remove_padding&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">sage_attn</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;sage_attn&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">sage_attn</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">sage_attn_q_block_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;sage_attn_q_block_size&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">sage_attn_q_block_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">sage_attn_k_block_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;sage_attn_k_block_size&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">sage_attn_k_block_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">sage_attn_v_block_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;sage_attn_v_block_size&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">sage_attn_v_block_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="k">if</span> <span class="n">cp_size</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># transpose q,k,v inside qkv to make kv contiguous, which is required by ring attention</span>
<span class="c1"># (b, s, 3d)</span>
<span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="n">chunk</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">bs</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">seq_len</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="c1"># (b, s, d) -&gt; (b, s, 2d) -&gt; (2b, s, d)</span>
<span class="n">kv</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">],</span>
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">((</span><span class="mi">2</span> <span class="o">*</span> <span class="n">bs</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])))</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">((</span><span class="n">query</span><span class="p">,</span> <span class="n">kv</span><span class="p">),</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">((</span><span class="n">bs</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">query</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="mi">3</span><span class="p">)))</span>
<span class="n">cp_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;cp_size&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">cp_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">cp_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;cp_rank&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">cp_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">cp_group</span> <span class="o">=</span> <span class="n">cp_group</span> <span class="ow">or</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">cp_group</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">cp_group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">cp_group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;cp_group&quot;</span><span class="p">,</span> <span class="n">cp_group</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
<span class="n">nheads</span><span class="p">,</span> <span class="n">head_size</span><span class="p">,</span> <span class="n">q_scaling</span><span class="p">,</span> <span class="n">context_fmha_type</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span>
<span class="n">do_relative_attention</span><span class="p">,</span> <span class="n">max_distance</span><span class="p">,</span> <span class="n">remove_padding</span><span class="p">,</span> <span class="n">sage_attn</span><span class="p">,</span>
<span class="n">sage_attn_q_block_size</span><span class="p">,</span> <span class="n">sage_attn_k_block_size</span><span class="p">,</span> <span class="n">sage_attn_v_block_size</span><span class="p">,</span>
<span class="n">cp_size</span><span class="p">,</span> <span class="n">cp_rank</span><span class="p">,</span> <span class="n">cp_group</span>
<span class="p">])</span>
<span class="n">attn_plug</span> <span class="o">=</span> <span class="n">attn_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;padding_attn&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">tensor</span><span class="p">,</span> <span class="n">input_lengths</span><span class="p">]</span>
<span class="k">if</span> <span class="n">max_input_length</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># for remove padding mode</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">max_input_length</span><span class="p">]</span>
<span class="k">if</span> <span class="n">relative_attention_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># for relative attention mode</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">relative_attention_bias</span><span class="p">]</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">attn_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">attn_plg_creator</span><span class="p">,</span> <span class="s2">&quot;padding_attn&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> \
<span class="sa">f</span><span class="s2">&quot;Plugin outputs number mismatch with expected, got </span><span class="si">{</span><span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span><span class="si">}</span><span class="s2">, expected 1&quot;</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">RopeEmbeddingUtils</span><span class="p">:</span>
<div class="viewcode-block" id="RopeEmbeddingUtils.apply_llama3_scaling">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.apply_llama3_scaling">[docs]</a>
<!-- Generated (Sphinx viewcode / Pygments) rendering of
     RopeEmbeddingUtils.apply_llama3_scaling. Rescales RoPE inverse frequencies
     per the Llama-3 rope_scaling rule visible below: frequencies whose wavelength
     is shorter than old_context_len/high_freq_factor are kept, those longer than
     old_context_len/low_freq_factor are divided by `factor`, and the band between
     the two thresholds is linearly interpolated via `smooth`. Do not hand-edit the
     highlight spans; regenerate from the Python source instead. -->
<span class="nd">@staticmethod</span>
<span class="c1"># ref: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L298</span>
<span class="k">def</span><span class="w"> </span><span class="nf">apply_llama3_scaling</span><span class="p">(</span><span class="n">inv_freqs</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">rope_scaling_config</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
<span class="n">scale_factor</span> <span class="o">=</span> <span class="n">rope_scaling_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;factor&quot;</span><span class="p">,</span> <span class="mf">8.0</span><span class="p">)</span>
<span class="n">low_freq_factor</span> <span class="o">=</span> <span class="n">rope_scaling_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;low_freq_factor&quot;</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">)</span>
<span class="n">high_freq_factor</span> <span class="o">=</span> <span class="n">rope_scaling_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;high_freq_factor&quot;</span><span class="p">,</span> <span class="mf">4.0</span><span class="p">)</span>
<span class="n">old_context_len</span> <span class="o">=</span> <span class="n">rope_scaling_config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
<span class="s2">&quot;original_max_position_embeddings&quot;</span><span class="p">,</span> <span class="mi">8192</span><span class="p">)</span>
<span class="n">low_freq_wavelen</span> <span class="o">=</span> <span class="n">old_context_len</span> <span class="o">/</span> <span class="n">low_freq_factor</span>
<span class="n">high_freq_wavelen</span> <span class="o">=</span> <span class="n">old_context_len</span> <span class="o">/</span> <span class="n">high_freq_factor</span>
<span class="n">new_inv_freqs</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">inv_freq</span> <span class="ow">in</span> <span class="n">inv_freqs</span><span class="p">:</span>
<span class="n">wavelen</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span> <span class="o">/</span> <span class="n">inv_freq</span>
<span class="k">if</span> <span class="n">wavelen</span> <span class="o">&lt;</span> <span class="n">high_freq_wavelen</span><span class="p">:</span>
<span class="n">new_inv_freqs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">inv_freq</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">wavelen</span> <span class="o">&gt;</span> <span class="n">low_freq_wavelen</span><span class="p">:</span>
<span class="n">new_inv_freqs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">inv_freq</span> <span class="o">/</span> <span class="n">scale_factor</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">low_freq_wavelen</span> <span class="o">!=</span> <span class="n">high_freq_wavelen</span>
<span class="n">smooth</span> <span class="o">=</span> <span class="p">(</span><span class="n">old_context_len</span> <span class="o">/</span> <span class="n">wavelen</span> <span class="o">-</span> <span class="n">low_freq_factor</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span>
<span class="n">high_freq_factor</span> <span class="o">-</span> <span class="n">low_freq_factor</span><span class="p">)</span>
<span class="n">new_inv_freqs</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="mi">1</span> <span class="o">-</span> <span class="n">smooth</span><span class="p">)</span> <span class="o">*</span> <span class="n">inv_freq</span> <span class="o">/</span> <span class="n">scale_factor</span> <span class="o">+</span>
<span class="n">smooth</span> <span class="o">*</span> <span class="n">inv_freq</span><span class="p">)</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">new_inv_freqs</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">inv_freqs</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.create_sinusoidal_positions">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_sinusoidal_positions">[docs]</a>
<!-- Generated (Sphinx viewcode / Pygments) rendering of
     RopeEmbeddingUtils.create_sinusoidal_positions. Builds the classic RoPE
     sin/cos table: inv_freq = 1/theta^(2i/dim), outer product with positions
     0..num_pos-1, then concatenates sin and cos halves along axis 1 and adds a
     leading batch axis, returning an array of shape (1, num_pos, dim) cast to
     `dtype`. Do not hand-edit the highlight spans; regenerate from the Python
     source instead. -->
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">create_sinusoidal_positions</span><span class="p">(</span><span class="n">num_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">theta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10000.0</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">):</span>
<span class="n">inv_freq</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="n">theta</span><span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;i , j -&gt; i j&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_pos</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
<span class="n">inv_freq</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)),</span>
<span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">concat</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">create_sinusoidal_positions_for_attention_plugin</span><span class="p">(</span>
<span class="n">num_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">theta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10000.0</span><span class="p">,</span>
<span class="n">scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">scale_type</span><span class="p">:</span> <span class="n">RotaryScalingType</span> <span class="o">=</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">none</span><span class="p">,</span>
<span class="c1"># Other scaling configs that only used by certain scaling types.</span>
<span class="n">rope_scaling_config</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">):</span>
<span class="k">if</span> <span class="n">scale_type</span> <span class="o">==</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">linear</span><span class="p">:</span>
<span class="n">scale</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">scale</span>
<span class="k">if</span> <span class="n">scale_type</span> <span class="o">==</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">llama3</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">rope_scaling_config</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;rotary_scaling config must be provided.&quot;</span>
<span class="n">inv_freq</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="n">theta</span><span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">inv_freq</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">apply_llama3_scaling</span><span class="p">(</span>
<span class="n">inv_freq</span><span class="p">,</span> <span class="n">rope_scaling_config</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">scale_type</span> <span class="o">==</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">dynamic</span><span class="p">:</span>
<span class="c1"># Make sure scaling_alpha exists in rope_scaling</span>
<span class="c1"># Ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct-FP8/blob/main/modeling_hunyuan.py#L346</span>
<span class="k">assert</span> <span class="n">rope_scaling_config</span><span class="p">[</span>
<span class="s2">&quot;alpha&quot;</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;rope_scaling_config.alpha must be provided.&quot;</span>
<span class="n">scaling_alpha</span> <span class="o">=</span> <span class="n">rope_scaling_config</span><span class="p">[</span><span class="s2">&quot;alpha&quot;</span><span class="p">]</span>
<span class="n">adjusted_base</span> <span class="o">=</span> <span class="n">theta</span> <span class="o">*</span> <span class="p">(</span><span class="n">scaling_alpha</span><span class="o">**</span><span class="p">(</span><span class="n">dim</span> <span class="o">/</span> <span class="p">(</span><span class="n">dim</span> <span class="o">-</span> <span class="mi">2</span><span class="p">)))</span>
<span class="n">inv_freq</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="n">adjusted_base</span><span class="o">**</span><span class="p">(</span>
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">inv_freq</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">/</span> <span class="p">(</span><span class="n">theta</span>
<span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;i , j -&gt; i j&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_pos</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
<span class="n">inv_freq</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># fuse cos/sin into float2 (cos, sin).</span>
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
<span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)),</span>
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="c1">#np.cos(sinusoid_inp).shape = (32768, 64, 1)</span>
<span class="k">return</span> <span class="n">inv_freq</span><span class="p">,</span> <span class="n">concat</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.create_sinusoidal_positions_for_cogvlm_attention_plugin">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_sinusoidal_positions_for_cogvlm_attention_plugin">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">create_sinusoidal_positions_for_cogvlm_attention_plugin</span><span class="p">(</span>
<span class="n">num_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">theta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10000.0</span><span class="p">,</span>
<span class="n">scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">scale_type</span><span class="p">:</span> <span class="n">RotaryScalingType</span> <span class="o">=</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">none</span><span class="p">,</span>
<span class="n">vision_start</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">vision_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1225</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">):</span>
<span class="k">if</span> <span class="n">scale_type</span> <span class="o">==</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">linear</span><span class="p">:</span>
<span class="n">scale</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">scale</span>
<span class="n">inv_freq</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">/</span> <span class="p">(</span><span class="n">theta</span><span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">position_id</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">([</span>
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">vision_start</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
<span class="n">np</span><span class="o">.</span><span class="n">full</span><span class="p">(</span><span class="n">vision_length</span><span class="p">,</span> <span class="n">vision_start</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">vision_start</span> <span class="o">+</span> <span class="mi">2</span><span class="p">,</span>
<span class="n">num_pos</span> <span class="o">-</span> <span class="p">(</span><span class="n">vision_length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="p">])</span>
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;i , j -&gt; i j&quot;</span><span class="p">,</span>
<span class="n">position_id</span><span class="p">,</span>
<span class="n">inv_freq</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># fuse cos/sin into float2 (cos, sin).</span>
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)),</span>
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">inv_freq</span><span class="p">,</span> <span class="n">concat</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.create_sinusoidal_positions_long_rope_for_attention_plugin">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_sinusoidal_positions_long_rope_for_attention_plugin">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">create_sinusoidal_positions_long_rope_for_attention_plugin</span><span class="p">(</span>
<span class="n">num_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">num_orig_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">theta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10000.0</span><span class="p">,</span>
<span class="n">scaling_short_factors</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">scaling_long_factors</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">short_mscale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">long_mscale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_calc_mscale</span><span class="p">(</span><span class="n">scale</span><span class="p">):</span>
<span class="k">if</span> <span class="n">scale</span> <span class="o">&lt;=</span> <span class="mf">1.0</span><span class="p">:</span>
<span class="k">return</span> <span class="mf">1.0</span>
<span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">scale</span><span class="p">)</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">num_orig_pos</span><span class="p">))</span>
<span class="k">if</span> <span class="n">short_mscale</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">short_mscale</span> <span class="o">=</span> <span class="n">_calc_mscale</span><span class="p">(</span><span class="n">num_pos</span> <span class="o">/</span> <span class="n">num_orig_pos</span><span class="p">)</span>
<span class="n">long_mscale</span> <span class="o">=</span> <span class="n">short_mscale</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_compute_sinusoidal_positions</span><span class="p">(</span><span class="n">scale_factors</span><span class="p">,</span> <span class="n">is_short</span><span class="p">,</span>
<span class="n">for_attention_plugin</span><span class="p">):</span>
<span class="n">inv_freq</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="n">scale_factors</span> <span class="o">*</span>
<span class="p">(</span><span class="n">theta</span><span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">))</span>
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;i , j -&gt; i j&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_pos</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">),</span>
<span class="n">inv_freq</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">for_attention_plugin</span><span class="p">:</span>
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
<span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)),</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
<span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">concat</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">mscale</span> <span class="o">=</span> <span class="n">short_mscale</span> <span class="k">if</span> <span class="n">is_short</span> <span class="k">else</span> <span class="n">long_mscale</span>
<span class="n">concat</span> <span class="o">=</span> <span class="n">concat</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span> <span class="o">*</span> <span class="n">mscale</span>
<span class="c1"># gpt attention plugins also need inv_freq.</span>
<span class="k">if</span> <span class="n">for_attention_plugin</span><span class="p">:</span>
<span class="k">return</span> <span class="n">inv_freq</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">concat</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">concat</span>
<span class="k">return</span> <span class="n">_compute_sinusoidal_positions</span><span class="p">(</span>
<span class="n">scaling_short_factors</span><span class="p">,</span> <span class="kc">True</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span> <span class="n">_compute_sinusoidal_positions</span><span class="p">(</span>
<span class="n">scaling_long_factors</span><span class="p">,</span>
<span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">),</span> <span class="n">_compute_sinusoidal_positions</span><span class="p">(</span>
<span class="n">scaling_short_factors</span><span class="p">,</span> <span class="kc">True</span><span class="p">,</span>
<span class="kc">True</span><span class="p">),</span> <span class="n">_compute_sinusoidal_positions</span><span class="p">(</span>
<span class="n">scaling_long_factors</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span> <span class="n">short_mscale</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.create_sinusoidal_positions_long_rope">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_sinusoidal_positions_long_rope">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">create_sinusoidal_positions_long_rope</span><span class="p">(</span>
<span class="n">num_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">theta</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
<span class="n">original_max_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">short_factor</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span>
<span class="n">long_factor</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span>
<span class="n">max_seq_len</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="n">short_factor</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">short_factor</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">long_factor</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">long_factor</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">inv_freq</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="n">theta</span><span class="o">**</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">))</span>
<span class="n">t_pos</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">([</span><span class="n">num_pos</span><span class="p">,</span> <span class="n">original_max_pos</span><span class="p">]),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="c1"># Choose proper freqs based on max_seq_len.</span>
<span class="n">factor</span> <span class="o">=</span> <span class="n">long_factor</span> <span class="k">if</span> <span class="n">max_seq_len</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">max_seq_len</span> <span class="o">&gt;</span> <span class="n">original_max_pos</span> <span class="k">else</span> <span class="n">short_factor</span>
<span class="n">inv_freq</span> <span class="o">=</span> <span class="n">inv_freq</span> <span class="o">/</span> <span class="n">factor</span>
<span class="n">freqs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;i,j-&gt;ij&quot;</span><span class="p">,</span> <span class="n">t_pos</span><span class="p">,</span> <span class="n">inv_freq</span><span class="p">)</span>
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">freqs</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)[</span><span class="o">...</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">]</span>
<span class="c1"># Apply scaling</span>
<span class="n">scale</span> <span class="o">=</span> <span class="n">num_pos</span> <span class="o">/</span> <span class="n">original_max_pos</span>
<span class="k">if</span> <span class="n">scale</span> <span class="o">&lt;=</span> <span class="mf">1.0</span><span class="p">:</span>
<span class="n">scaling_factor</span> <span class="o">=</span> <span class="mf">1.0</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">scaling_factor</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">+</span>
<span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">scale</span><span class="p">)</span> <span class="o">/</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">original_max_pos</span><span class="p">))</span>
<span class="c1"># fuse cos/sin into float2 (cos, sin).</span>
<span class="n">concat</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
<span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)</span> <span class="o">*</span> <span class="n">scaling_factor</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">sinusoid_inp</span><span class="p">)</span> <span class="o">*</span> <span class="n">scaling_factor</span><span class="p">),</span>
<span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="kc">None</span><span class="p">,</span> <span class="n">concat</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.create_fake_weight">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_fake_weight">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">create_fake_weight</span><span class="p">(</span><span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">half</span><span class="p">):</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span></div>
<span class="c1"># Note: When not using deepseek_yarn, make sure to set mscale_all_dim to 0.0.</span>
<div class="viewcode-block" id="RopeEmbeddingUtils.create_sinusoidal_positions_yarn">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.create_sinusoidal_positions_yarn">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">create_sinusoidal_positions_yarn</span><span class="p">(</span>
<span class="n">num_pos</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">base</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10000</span><span class="p">,</span>
<span class="n">scaling_factor</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">original_max_position_embeddings</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4096</span><span class="p">,</span>
<span class="n">beta_fast</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="n">beta_slow</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">mscale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">mscale_all_dim</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">duplicate_data</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">):</span>
<span class="c1"># Copy from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py</span>
<span class="c1"># Inverse dim formula to find dim based on number of rotations</span>
<span class="k">def</span><span class="w"> </span><span class="nf">yarn_find_correction_dim</span><span class="p">(</span><span class="n">num_rotations</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">base</span><span class="p">,</span>
<span class="n">max_position_embeddings</span><span class="p">):</span>
<span class="k">return</span> <span class="p">(</span><span class="n">dim</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">max_position_embeddings</span> <span class="o">/</span>
<span class="p">(</span><span class="n">num_rotations</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span><span class="p">)))</span> <span class="o">/</span> <span class="p">(</span>
<span class="mi">2</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">base</span><span class="p">))</span>
<span class="c1"># Find dim range bounds based on rotations</span>
<span class="k">def</span><span class="w"> </span><span class="nf">yarn_find_correction_range</span><span class="p">(</span><span class="n">low_rot</span><span class="p">,</span> <span class="n">high_rot</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">base</span><span class="p">,</span>
<span class="n">max_position_embeddings</span><span class="p">):</span>
<span class="n">low</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span>
<span class="n">yarn_find_correction_dim</span><span class="p">(</span><span class="n">low_rot</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">base</span><span class="p">,</span>
<span class="n">max_position_embeddings</span><span class="p">))</span>
<span class="n">high</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span>
<span class="n">yarn_find_correction_dim</span><span class="p">(</span><span class="n">high_rot</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">base</span><span class="p">,</span>
<span class="n">max_position_embeddings</span><span class="p">))</span>
<span class="k">if</span> <span class="n">low</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">low</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="n">high</span> <span class="o">&gt;</span> <span class="n">dim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">high</span> <span class="o">=</span> <span class="n">dim</span> <span class="o">-</span> <span class="mi">1</span>
<span class="k">return</span> <span class="n">low</span><span class="p">,</span> <span class="n">high</span> <span class="c1"># Clamp values just in case</span>
<span class="k">def</span><span class="w"> </span><span class="nf">yarn_get_mscale</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span> <span class="n">mscale</span><span class="p">):</span>
<span class="k">if</span> <span class="n">scale</span> <span class="o">&lt;=</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">return</span> <span class="mf">1.0</span>
<span class="k">return</span> <span class="mf">0.1</span> <span class="o">*</span> <span class="n">mscale</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">scale</span><span class="p">)</span> <span class="o">+</span> <span class="mf">1.0</span>
<span class="k">def</span><span class="w"> </span><span class="nf">yarn_linear_ramp_mask</span><span class="p">(</span><span class="nb">min</span><span class="p">,</span> <span class="nb">max</span><span class="p">,</span> <span class="n">dim</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">min</span> <span class="o">==</span> <span class="nb">max</span><span class="p">:</span>
<span class="nb">max</span> <span class="o">+=</span> <span class="mf">0.001</span> <span class="c1"># Prevent singularity</span>
<span class="n">linear_func</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">-</span> <span class="nb">min</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="nb">max</span> <span class="o">-</span> <span class="nb">min</span><span class="p">)</span>
<span class="n">ramp_func</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="n">linear_func</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">ramp_func</span>
<span class="n">pos_freqs</span> <span class="o">=</span> <span class="n">base</span><span class="o">**</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span> <span class="o">/</span> <span class="n">dim</span><span class="p">)</span>
<span class="n">freq_extra</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">pos_freqs</span>
<span class="n">freq_inter</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="n">scaling_factor</span> <span class="o">*</span> <span class="n">pos_freqs</span><span class="p">)</span>
<span class="n">low</span><span class="p">,</span> <span class="n">high</span> <span class="o">=</span> <span class="n">yarn_find_correction_range</span><span class="p">(</span>
<span class="n">beta_fast</span><span class="p">,</span>
<span class="n">beta_slow</span><span class="p">,</span>
<span class="n">dim</span><span class="p">,</span>
<span class="n">base</span><span class="p">,</span>
<span class="n">original_max_position_embeddings</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">inv_freq_mask</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">yarn_linear_ramp_mask</span><span class="p">(</span><span class="n">low</span><span class="p">,</span> <span class="n">high</span><span class="p">,</span> <span class="n">dim</span> <span class="o">//</span> <span class="mi">2</span><span class="p">))</span>
<span class="n">inv_freq</span> <span class="o">=</span> <span class="n">freq_inter</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">inv_freq_mask</span><span class="p">)</span> <span class="o">+</span> <span class="n">freq_extra</span> <span class="o">*</span> <span class="n">inv_freq_mask</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_pos</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">sinusoid_inp</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;i,j -&gt; ij&quot;</span><span class="p">,</span> <span class="n">t</span><span class="p">,</span> <span class="n">inv_freq</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">_mscale</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span>
<span class="n">yarn_get_mscale</span><span class="p">(</span><span class="n">scaling_factor</span><span class="p">,</span> <span class="n">mscale</span><span class="p">)</span> <span class="o">/</span>
<span class="n">yarn_get_mscale</span><span class="p">(</span><span class="n">scaling_factor</span><span class="p">,</span> <span class="n">mscale_all_dim</span><span class="p">))</span>
<span class="k">if</span> <span class="n">duplicate_data</span><span class="p">:</span>
<span class="n">emb</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">sinusoid_inp</span><span class="p">,</span> <span class="n">sinusoid_inp</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">2</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">emb</span> <span class="o">=</span> <span class="n">sinusoid_inp</span>
<span class="n">concat</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">torch</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">emb</span><span class="p">)</span> <span class="o">*</span> <span class="n">_mscale</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">emb</span><span class="p">)</span> <span class="o">*</span> <span class="n">_mscale</span><span class="p">),</span>
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">inv_freq</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">concat</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span></div>
<!-- Rendered source of RopeEmbeddingUtils.rotate_every_two (GPT-J style RoPE rotation).
     The highlighted Python below splits the last tensor dim into interleaved even/odd
     element pairs, negates the odd half, swaps the pair order, and restores the original
     last-dim size via view(). Auto-generated Sphinx viewcode markup: fix the Python
     source, not this HTML. -->
<div class="viewcode-block" id="RopeEmbeddingUtils.rotate_every_two">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.rotate_every_two">[docs]</a>
    <span class="nd">@staticmethod</span>
    <span class="k">def</span><span class="w"> </span><span class="nf">rotate_every_two</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
        <span class="k">assert</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">4</span>
        <span class="n">shape_tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
            <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span>
                                          <span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
            <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
        <span class="p">])</span>
        <span class="n">x1</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">shape_tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
        <span class="n">x2</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">shape_tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
        <span class="n">x1</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
        <span class="n">x2</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">x2</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
        <span class="n">zero</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span>
            <span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span>
                <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">1</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">))))</span>
        <span class="n">x2</span> <span class="o">=</span> <span class="n">zero</span> <span class="o">-</span> <span class="n">x2</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">x2</span><span class="p">,</span> <span class="n">x1</span><span class="p">],</span> <span class="mi">4</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">view</span><span class="p">(</span>
            <span class="n">x</span><span class="p">,</span> <span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
                       <span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
                       <span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span>
                       <span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span><span class="p">]))</span></div>
<!-- Rendered source of RopeEmbeddingUtils.rotate_half (GPT-NeoX style RoPE rotation).
     The highlighted Python below splits the last dim of a rank-4 tensor into two
     contiguous halves, negates the second half, and concatenates the halves in swapped
     order. Auto-generated Sphinx viewcode markup: fix the Python source, not this HTML. -->
<div class="viewcode-block" id="RopeEmbeddingUtils.rotate_half">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.rotate_half">[docs]</a>
    <span class="nd">@staticmethod</span>
    <span class="k">def</span><span class="w"> </span><span class="nf">rotate_half</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
        <span class="c1"># [bs, num_attention_kv_heads, seqlen, attention_head_size]</span>
        <span class="k">assert</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">4</span>
        <span class="n">shape_tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
            <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span>
                                          <span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
            <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
        <span class="p">])</span>
        <span class="n">last_dim</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span>
        <span class="n">x1</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">shape_tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
        <span class="n">x2</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">concat</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">last_dim</span><span class="p">]),</span> <span class="n">shape_tensor</span><span class="p">,</span>
                   <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
        <span class="n">zero</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span>
            <span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span>
                <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">1</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">dtype</span><span class="p">))))</span>
        <span class="n">x2</span> <span class="o">=</span> <span class="n">zero</span> <span class="o">-</span> <span class="n">x2</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">x2</span><span class="p">,</span> <span class="n">x1</span><span class="p">],</span> <span class="mi">3</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span></div>
<!-- Rendered source of RopeEmbeddingUtils.apply_rotary_pos_emb. The highlighted Python
     below dispatches on pos_emb_type: rope_gpt_neox/long_rope use rotate_half with
     cos/sin duplicated along the last dim; rope_gptj uses rotate_every_two with cos/sin
     repeat-interleaved; chatglm rotates the two contiguous halves of the last dim
     independently with (cos0, sin0) and (cos1, sin1) and returns early; any other type
     raises ValueError. Auto-generated Sphinx viewcode markup: fix the Python source,
     not this HTML. -->
<div class="viewcode-block" id="RopeEmbeddingUtils.apply_rotary_pos_emb">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.apply_rotary_pos_emb">[docs]</a>
    <span class="nd">@staticmethod</span>
    <span class="k">def</span><span class="w"> </span><span class="nf">apply_rotary_pos_emb</span><span class="p">(</span>
        <span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
        <span class="n">position_embedding</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
        <span class="n">pos_emb_type</span><span class="p">:</span> <span class="n">PositionEmbeddingType</span> <span class="o">=</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gptj</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
        <span class="n">rotate_func</span> <span class="o">=</span> <span class="kc">None</span>
        <span class="k">if</span> <span class="n">pos_emb_type</span> <span class="o">==</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gpt_neox</span> <span class="ow">or</span> <span class="n">pos_emb_type</span> <span class="o">==</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">long_rope</span><span class="p">:</span>
            <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span>
            <span class="n">cos</span><span class="p">,</span> <span class="n">sin</span> <span class="o">=</span> <span class="n">position_embedding</span>
            <span class="n">sin</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">sin</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
            <span class="n">cos</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">cos</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
            <span class="n">sin</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">sin</span><span class="p">,</span> <span class="n">sin</span><span class="p">],</span> <span class="mi">3</span><span class="p">)</span>
            <span class="n">cos</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">cos</span><span class="p">,</span> <span class="n">cos</span><span class="p">],</span> <span class="mi">3</span><span class="p">)</span>
            <span class="n">rotate_func</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">rotate_half</span>
        <span class="k">elif</span> <span class="n">pos_emb_type</span> <span class="o">==</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gptj</span><span class="p">:</span>
            <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span>
            <span class="n">cos</span><span class="p">,</span> <span class="n">sin</span> <span class="o">=</span> <span class="n">position_embedding</span>
            <span class="n">sin</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">sin</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
            <span class="n">cos</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">cos</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
            <span class="n">sin</span> <span class="o">=</span> <span class="n">repeat_interleave</span><span class="p">(</span><span class="n">sin</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
            <span class="n">cos</span> <span class="o">=</span> <span class="n">repeat_interleave</span><span class="p">(</span><span class="n">cos</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
            <span class="n">rotate_func</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">rotate_every_two</span>
        <span class="k">elif</span> <span class="n">pos_emb_type</span> <span class="o">==</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">chatglm</span><span class="p">:</span>
            <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">)</span> <span class="o">==</span> <span class="mi">4</span>
            <span class="n">cos0</span><span class="p">,</span> <span class="n">cos1</span><span class="p">,</span> <span class="n">sin0</span><span class="p">,</span> <span class="n">sin1</span> <span class="o">=</span> <span class="n">position_embedding</span>
            <span class="n">shape_tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
                <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span>
                                              <span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
                <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
            <span class="p">])</span>
            <span class="n">last_dim</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span>
            <span class="n">x_part0</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">shape_tensor</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
            <span class="n">x_part1</span> <span class="o">=</span> <span class="nb">slice</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">concat</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">last_dim</span><span class="p">]),</span> <span class="n">shape_tensor</span><span class="p">,</span>
                            <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
            <span class="n">y_part0</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_part0</span> <span class="o">*</span>
                       <span class="n">cos0</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">rotate_half</span><span class="p">(</span><span class="n">x_part0</span><span class="p">)</span> <span class="o">*</span> <span class="n">sin0</span><span class="p">)</span>
            <span class="n">y_part1</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_part1</span> <span class="o">*</span>
                       <span class="n">cos1</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">rotate_half</span><span class="p">(</span><span class="n">x_part1</span><span class="p">)</span> <span class="o">*</span> <span class="n">sin1</span><span class="p">)</span>
            <span class="n">result</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">y_part0</span><span class="p">,</span> <span class="n">y_part1</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
            <span class="k">return</span> <span class="n">result</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">))</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;The PositionEmbeddingType is not RoPE&#39;</span><span class="p">)</span>
        <span class="k">return</span> <span class="p">(</span><span class="n">tensor</span> <span class="o">*</span> <span class="n">cos</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="n">rotate_func</span><span class="p">(</span><span class="n">tensor</span><span class="p">)</span> <span class="o">*</span> <span class="n">sin</span><span class="p">)</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.apply_rotary_pos_emb_chatglm">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.apply_rotary_pos_emb_chatglm">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">apply_rotary_pos_emb_chatglm</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="n">position_embedding</span><span class="p">,</span>
<span class="n">num_attention_heads</span><span class="p">,</span> <span class="n">attention_head_size</span><span class="p">,</span>
<span class="n">max_position_embeddings</span><span class="p">,</span>
<span class="n">rotary_embedding_scale</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="n">half_head_size</span> <span class="o">=</span> <span class="n">attention_head_size</span> <span class="o">//</span> <span class="mi">2</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="k">else</span> <span class="n">qkv</span>
<span class="n">input_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">remove_input_padding</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">seqlen</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span> <span class="k">if</span> <span class="n">remove_input_padding</span> <span class="k">else</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="n">qkv</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">seqlen</span><span class="p">,</span>
<span class="n">num_attention_heads</span><span class="p">,</span>
<span class="mi">3</span><span class="p">,</span>
<span class="n">attention_head_size</span><span class="p">,</span>
<span class="p">]))</span>
<span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
<span class="n">q_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">seqlen</span><span class="p">,</span>
<span class="n">num_attention_heads</span><span class="p">,</span>
<span class="n">attention_head_size</span><span class="p">,</span>
<span class="p">])</span>
<span class="n">query</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
<span class="n">value</span> <span class="o">=</span> <span class="n">value</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">create_sinusoidal_positions</span><span class="p">(</span>
<span class="n">max_position_embeddings</span><span class="p">,</span> <span class="n">half_head_size</span><span class="p">)</span>
<span class="n">embedding_weight</span> <span class="o">/=</span> <span class="n">rotary_embedding_scale</span>
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">embedding_weight</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="mi">2</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
<span class="p">[</span>
<span class="n">embedding_weight</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
<span class="n">embedding_weight</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span>
<span class="n">embedding_weight</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span>
<span class="n">embedding_weight</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span>
<span class="p">],</span>
<span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">position_embedding</span> <span class="o">=</span> <span class="n">unsqueeze</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">embedding_weight</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">trt_dtype_to_np</span><span class="p">(</span><span class="n">query</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span>
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">embedding_weight</span><span class="p">)</span>
<span class="n">position_embedding</span> <span class="o">=</span> <span class="n">embedding</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">,</span> <span class="n">embedding_weight</span><span class="p">)</span>
<span class="n">position_embedding</span><span class="p">,</span> <span class="n">block_embedding</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span>
<span class="n">position_embedding</span><span class="p">,</span>
<span class="mi">1</span><span class="p">,</span>
<span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">sin0</span><span class="p">,</span> <span class="n">cos0</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">,</span> <span class="n">half_head_size</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
<span class="n">sin1</span><span class="p">,</span> <span class="n">cos1</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="n">block_embedding</span><span class="p">,</span> <span class="n">half_head_size</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">seqlen</span><span class="p">,</span>
<span class="mi">1</span><span class="p">,</span>
<span class="n">half_head_size</span><span class="p">,</span>
<span class="p">])</span>
<span class="n">position_embedding</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">tensor</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span> <span class="k">for</span> <span class="n">tensor</span> <span class="ow">in</span> <span class="p">[</span><span class="n">cos0</span><span class="p">,</span> <span class="n">cos1</span><span class="p">,</span> <span class="n">sin0</span><span class="p">,</span> <span class="n">sin1</span><span class="p">]</span>
<span class="p">]</span>
<span class="n">query</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">apply_rotary_pos_emb</span><span class="p">(</span>
<span class="n">tensor</span><span class="o">=</span><span class="n">query</span><span class="p">,</span>
<span class="n">position_embedding</span><span class="o">=</span><span class="n">position_embedding</span><span class="p">,</span>
<span class="n">pos_emb_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">chatglm</span><span class="p">)</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">apply_rotary_pos_emb</span><span class="p">(</span>
<span class="n">tensor</span><span class="o">=</span><span class="n">key</span><span class="p">,</span>
<span class="n">position_embedding</span><span class="o">=</span><span class="n">position_embedding</span><span class="p">,</span>
<span class="n">pos_emb_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">chatglm</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">qkv</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">query</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
<span class="n">key</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
<span class="n">value</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
<span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">qkv</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span>
<span class="k">return</span> <span class="n">qkv</span></div>
<div class="viewcode-block" id="RopeEmbeddingUtils.apply_rotary_pos_emb_cogvlm">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.RopeEmbeddingUtils.apply_rotary_pos_emb_cogvlm">[docs]</a>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">apply_rotary_pos_emb_cogvlm</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="n">position_embedding</span><span class="p">,</span>
<span class="n">num_attention_heads</span><span class="p">,</span> <span class="n">attention_head_size</span><span class="p">,</span>
<span class="n">max_position_embeddings</span><span class="p">,</span>
<span class="n">rotary_embedding_scale</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="k">else</span> <span class="n">qkv</span>
<span class="n">input_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">remove_input_padding</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">seqlen</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="mi">0</span> <span class="k">if</span> <span class="n">remove_input_padding</span> <span class="k">else</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="n">qkv</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">seqlen</span><span class="p">,</span>
<span class="mi">3</span><span class="p">,</span>
<span class="n">num_attention_heads</span><span class="p">,</span>
<span class="n">attention_head_size</span><span class="p">,</span>
<span class="p">]))</span>
<span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">q_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
<span class="n">batch_size</span><span class="p">,</span>
<span class="n">seqlen</span><span class="p">,</span>
<span class="n">num_attention_heads</span><span class="p">,</span>
<span class="n">attention_head_size</span><span class="p">,</span>
<span class="p">])</span>
<span class="n">query</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">key</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
<span class="n">value</span> <span class="o">=</span> <span class="n">value</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">q_shape</span><span class="p">)</span>
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">create_sinusoidal_positions</span><span class="p">(</span>
<span class="n">max_position_embeddings</span><span class="p">,</span> <span class="n">attention_head_size</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<span class="n">embedding_weight</span> <span class="o">/=</span> <span class="n">rotary_embedding_scale</span> <span class="c1"># [max_position_embeddings, attention_head_size]</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">position_embedding</span> <span class="o">=</span> <span class="n">unsqueeze</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="c1"># [1, seqlen]</span>
<span class="n">embedding_weight</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">embedding_weight</span><span class="p">)</span> <span class="c1"># float32</span>
<span class="n">position_embedding</span> <span class="o">=</span> <span class="n">embedding</span><span class="p">(</span>
<span class="n">position_embedding</span><span class="p">,</span>
<span class="n">embedding_weight</span><span class="p">)</span> <span class="c1"># [1, seqlen, attention_head_size]</span>
<span class="n">sin</span><span class="p">,</span> <span class="n">cos</span> <span class="o">=</span> <span class="n">split</span><span class="p">(</span><span class="n">position_embedding</span><span class="p">,</span> <span class="n">attention_head_size</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span>
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># [1, seqlen, attention_head_size//2]</span>
<span class="n">input_dtype</span> <span class="o">=</span> <span class="n">query</span><span class="o">.</span><span class="n">dtype</span>
<span class="n">fp32_query</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="s2">&quot;float32&quot;</span><span class="p">)</span>
<span class="n">fp32_key</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="s2">&quot;float32&quot;</span><span class="p">)</span>
<span class="n">fp32_query</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">apply_rotary_pos_emb</span><span class="p">(</span>
<span class="n">tensor</span><span class="o">=</span><span class="n">fp32_query</span><span class="p">,</span>
<span class="n">position_embedding</span><span class="o">=</span><span class="p">[</span><span class="n">cos</span><span class="p">,</span> <span class="n">sin</span><span class="p">],</span>
<span class="n">pos_emb_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">)</span>
<span class="n">fp32_key</span> <span class="o">=</span> <span class="n">RopeEmbeddingUtils</span><span class="o">.</span><span class="n">apply_rotary_pos_emb</span><span class="p">(</span>
<span class="n">tensor</span><span class="o">=</span><span class="n">fp32_key</span><span class="p">,</span>
<span class="n">position_embedding</span><span class="o">=</span><span class="p">[</span><span class="n">cos</span><span class="p">,</span> <span class="n">sin</span><span class="p">],</span>
<span class="n">pos_emb_type</span><span class="o">=</span><span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">rope_gpt_neox</span><span class="p">)</span>
<span class="n">query</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">fp32_query</span><span class="p">,</span> <span class="n">input_dtype</span><span class="p">)</span>
<span class="n">key</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">fp32_key</span><span class="p">,</span> <span class="n">input_dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">qkv</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">query</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
<span class="n">key</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
<span class="n">value</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">),</span>
<span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">qkv</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">qkv</span> <span class="o">=</span> <span class="n">qkv</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span>
<span class="k">return</span> <span class="n">qkv</span></div>
</div>
<div class="viewcode-block" id="gpt_attention">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gpt_attention">[docs]</a>
<span class="nd">@gw</span><span class="o">.</span><span class="n">record_signature</span>
<span class="k">def</span><span class="w"> </span><span class="nf">gpt_attention</span><span class="p">(</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">qkv</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">past_key_value</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">attention_mask</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">attention_packed_mask</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">sequence_length</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_past_key_value_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">host_max_attention_window_sizes</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_sink_token_length</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">cache_indirection</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">layer_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">num_kv_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">hidden_size_per_head</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">q_scaling</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
<span class="n">attn_logit_softcapping_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">rotary_embedding_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">rotary_embedding_base</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">10000.0</span><span class="p">,</span>
<span class="n">rotary_embedding_scale_type</span><span class="p">:</span> <span class="n">RotaryScalingType</span> <span class="o">=</span> <span class="n">RotaryScalingType</span><span class="o">.</span><span class="n">none</span><span class="p">,</span>
<span class="n">rotary_embedding_short_m_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">rotary_embedding_long_m_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">rotary_embedding_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">rotary_embedding_max_positions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1024</span><span class="p">,</span>
<span class="n">rotary_embedding_original_max_positions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1024</span><span class="p">,</span>
<span class="n">position_embedding_type</span><span class="p">:</span> <span class="n">PositionEmbeddingType</span> <span class="o">=</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span>
<span class="n">learned_absolute</span><span class="p">,</span>
<span class="n">rotary_inv_freq</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">rotary_cos_sin</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">kv_orig_quant_scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">kv_quant_orig_scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">attention_output_orig_quant_scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">attention_output_sf_scale</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">kv_cache_quant_mode</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">QuantModeWrapper</span><span class="p">,</span> <span class="n">QuantMode</span><span class="p">]</span> <span class="o">=</span> <span class="n">QuantMode</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
<span class="n">max_context_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">mask_type</span><span class="p">:</span> <span class="n">AttentionMaskType</span> <span class="o">=</span> <span class="n">AttentionMaskType</span><span class="o">.</span><span class="n">causal</span><span class="p">,</span>
<span class="n">block_sparse_block_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span><span class="p">,</span>
<span class="n">block_sparse_homo_head_pattern</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">block_sparse_num_local_blocks</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span><span class="p">,</span>
<span class="n">block_sparse_vertical_stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span>
<span class="n">alibi_slopes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">tp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">vision_start</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
<span class="n">vision_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
<span class="n">kv_cache_block_offsets</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">host_kv_cache_block_offsets</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">host_kv_cache_pool_pointers</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">host_kv_cache_pool_mapping</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">do_cross_attention</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">cross_kv</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for cross attention</span>
<span class="n">cross_kv_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for cross attention</span>
<span class="n">encoder_input_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for cross attention</span>
<span class="n">relative_attention_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for relative attention</span>
<span class="n">logn_scaling</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for logn scaling</span>
<span class="n">max_distance</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="c1"># for relative attention</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for pad-free input mode</span>
<span class="n">qkv_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">use_cache</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">spec_decoding_is_generation_length_variable</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">spec_decoding_max_generation_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">spec_decoding_generation_lengths</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">spec_decoding_position_offsets</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">spec_decoding_packed_mask</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">spec_decoding_use</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">long_rope_rotary_inv_freq</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">long_rope_rotary_cos_sin</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">mrope_rotary_cos_sin</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">mrope_position_deltas</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">host_runtime_perf_knobs</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">host_context_progress</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">is_mla_enabled_flag</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">q_lora_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">kv_lora_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">qk_nope_head_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">qk_rope_head_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">v_head_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">q_b_proj</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">kv_b_proj</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">k_b_proj_trans</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">skip_attn</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">cp_group</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">],</span>
<span class="n">cp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">cp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">num_kv_heads_origin</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]]:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation that performs the multi-head attention in GPT-like models.</span>
<span class="sd"> The signature of the function will change in the future release - we are in</span>
<span class="sd"> the process of simplifying the API. The current version is still</span>
<span class="sd"> work-in-progress! The following API is provided with hints regarding the</span>
<span class="sd"> arguments that are likely to be removed or merged with others in the future</span>
<span class="sd"> release.</span>
<span class="sd"> See docs/source/advanced/gpt-attention.md for the documentation of that function.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> qkv: Tensor (On GPU)</span>
<span class="sd"> The input QKV tensor. Its shape is [batch_beam_size, max_seqlen, qkv_dim] in padded mode and [num_tokens, qkv_dim] in</span>
<span class="sd"> packed mode. Where qkv_dim depends on using MQA, GQA, or MHA. See QKV Input in docs/source/advanced/gpt-attention.md,</span>
<span class="sd"> past_key_value: Tensor (On GPU)</span>
<span class="sd"> The tensor that stores KV cache data. Its shape is</span>
<span class="sd"> [max_batch_size * max_beam_width, 2, num_kv_heads, max_seqlen, hidden_dim_per_head]</span>
<span class="sd"> in contiguous mode and</span>
<span class="sd"> [max_blocks, 2, num_kv_heads, num_tokens_per_block, hidden_dim_per_head]</span>
<span class="sd"> in paged mode. See KV Cache in docs/source/advanced/gpt-attention.md,</span>
<span class="sd"> attention_mask: Tensor (On GPU)</span>
<span class="sd"> The tensor that stores the attention mask for unfused MHA or MMHA.</span>
<span class="sd"> Its shape is [num_tokens, max_kv_seqlen].</span>
<span class="sd"> attention_packed_mask: Tensor (On GPU)</span>
<span class="sd"> The tensor that stores the packed custom mask for fmha.</span>
<span class="sd"> Its shape is [num_tokens, max_kv_seqlen / 32], where each bit represents one mask position.</span>
<span class="sd"> sequence_lengths: Tensor (On GPU)</span>
<span class="sd"> The tensor that stores the length of each sequence. Its shape is</span>
<span class="sd"> [batch_size]. See QKV Input in docs/source/advanced/gpt-attention.md,</span>
<span class="sd"> host_past_key_value_lengths: Tensor (On CPU)</span>
<span class="sd"> An INT32 tensor of shape [batch_size],</span>
<span class="sd"> host_max_attention_window_sizes: Tensor (On CPU)</span>
<span class="sd"> An INT32 tensor of shape [1].</span>
<span class="sd"> by default, the max_attention_window_size is determined by the shape of cache_indir_table.</span>
<span class="sd"> And we support independent max_attention_window_size for each layer.</span>
<span class="sd"> This controls the sliding-window-attention kv-cache features.</span>
<span class="sd"> context_lengths: Tensor (On GPU)</span>
<span class="sd"> The tensor that stores the context-phase sequence length of each request. Its shape</span>
<span class="sd"> is [batch_size]. See QKV Input in docs/source/advanced/gpt-attention.md,</span>
<span class="sd"> cache_indirection: Tensor (On GPU)</span>
<span class="sd"> The tensor to reconstruct the paths when using beam-search. Its</span>
<span class="sd"> shape is [batch_size, beam_width, max_seqlen]. See Beam-Search in</span>
<span class="sd"> docs/source/advanced/gpt-attention.md,</span>
<span class="sd"> host_request_types: Tensor = None (On CPU)</span>
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
<span class="sd"> in docs/source/advanced/gpt-attention.md,</span>
<span class="sd"> layer_idx: int</span>
<span class="sd"> The index of this attention layer, used to access kv_cache_block_offsets,</span>
<span class="sd"> num_heads: int</span>
<span class="sd"> The number of heads,</span>
<span class="sd"> num_kv_heads: int</span>
<span class="sd"> The number of KV heads, generic to handle MHA/MQA/GQA,</span>
<span class="sd"> hidden_size_per_head: int</span>
<span class="sd"> The hidden size per head,</span>
<span class="sd"> q_scaling: float</span>
<span class="sd"> The value used to compute the scaling factor applied to the output</span>
<span class="sd"> of the Q*K^T product. See Scaling Factors in docs/source/advanced/gpt-attention.md,</span>
<span class="sd"> attn_logit_softcapping_scale: float</span>
<span class="sd"> The scale * tanh(value / scale) used to compute the scaling factor applied to the output</span>
<span class="sd"> of the Q*K^T product.</span>
<span class="sd"> rotary_embedding_dim: int</span>
<span class="sd"> The dimension to compute RoPE. Use 0 when position_embedding_type is not RoPE.</span>
<span class="sd"> rotary_embedding_base: float</span>
<span class="sd"> The theta value to use for RoPE. Ignored when position_embedding_type is not RoPE.</span>
<span class="sd"> rotary_embedding_scale_type: RotaryScalingType</span>
<span class="sd"> The scaling type of RoPE. Ignored when position_embedding_type is not RoPE.</span>
<span class="sd"> Possible rotary scaling type:</span>
<span class="sd"> * RotaryScalingType.none</span>
<span class="sd"> * RotaryScalingType.linear</span>
<span class="sd"> * RotaryScalingType.dynamic</span>
<span class="sd"> * RotaryScalingType.longrope</span>
<span class="sd"> * RotaryScalingType.llama3</span>
<span class="sd"> rotary_embedding_scale: float</span>
<span class="sd"> The scale value to use for linear/dynamic scaling in RoPE.</span>
<span class="sd"> Ignored when position_embedding_type is not RoPE.</span>
<span class="sd"> Must be set to 1 (default) if rotary_embedding_scale_type is `none`.</span>
<span class="sd"> rotary_inv_freq: float Tensor</span>
<span class="sd"> The rotary inv freq with shape [head_size / 2].</span>
<span class="sd"> rotary_cos_sin: float2(cos/sin) Tensor</span>
<span class="sd"> The rotary cos/sin cache, which will be reused among different requests.</span>
<span class="sd"> It is taken as constant tensor.</span>
<span class="sd"> rotary_embedding_max_positions: int</span>
<span class="sd"> Needed only for `dynamic` RoPE scaling. Ignored otherwise.</span>
<span class="sd"> position_embedding_type: PositionEmbeddingType</span>
<span class="sd"> The position embedding type:</span>
<span class="sd"> * PositionEmbeddingType.learned_absolute</span>
<span class="sd"> * PositionEmbeddingType.relative</span>
<span class="sd"> * PositionEmbeddingType.rope_gptj</span>
<span class="sd"> * PositionEmbeddingType.rope_gpt_neox</span>
<span class="sd"> * PositionEmbeddingType.alibi</span>
<span class="sd"> * PositionEmbeddingType.alibi_with_scale</span>
<span class="sd"> kv_orig_quant_scale: Tensor</span>
<span class="sd"> The tensor to store the scaling factor for quantization to INT8/FP8</span>
<span class="sd"> in the KV cache. Its shape is [1]. See INT8/FP8 KV Cache in</span>
<span class="sd"> docs/source/advanced/gpt-attention.md,</span>
<span class="sd"> kv_quant_orig_scale: Tensor</span>
<span class="sd"> The tensor to store the scaling factor for dequantization from</span>
<span class="sd"> INT8/FP8 in the KV cache. Its shape is [1]. See INT8/FP8 KV Cache</span>
<span class="sd"> in docs/source/advanced/gpt-attention.md,</span>
<span class="sd"> attention_output_orig_quant_scale: Tensor</span>
<span class="sd"> The tensor to store the scaling factor for quantization to FP8</span>
<span class="sd"> in the KV cache. Its shape is [1].</span>
<span class="sd"> kv_cache_quant_mode: QuantMode (int flags)</span>
<span class="sd"> Do we enable the INT8 or FP8 KV cache?</span>
<span class="sd"> max_context_length: int32_t</span>
<span class="sd"> The length of the longest input sequence. See QKV Input in</span>
<span class="sd"> docs/source/advanced/gpt-attention.md,</span>
<span class="sd"> mask_type: int = 1</span>
<span class="sd"> The type of mask:</span>
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.padding for BERT,</span>
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.causal for GPT,</span>
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.sliding_window_causal for GPT,</span>
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.bidirectional for ChatGLM-6B,</span>
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.bidirectionalglm for GLM-10B,</span>
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.blocksparse for Phi-3-small,</span>
<span class="sd"> * tensorrt_llm.layers.AttentionMaskType.custom_mask for any models.</span>
<span class="sd"> block_sparse_block_size: int</span>
<span class="sd"> Block size in block sparse attention</span>
<span class="sd"> block_sparse_homo_head_pattern: bool</span>
<span class="sd"> Do all attention heads share same vertical stride pattern?</span>
<span class="sd"> block_sparse_num_local_blocks: int</span>
<span class="sd"> Number of active blocks near diagonal</span>
<span class="sd"> block_sparse_vertical_stride: int</span>
<span class="sd"> Stride of active blocks in vertical dimension</span>
<span class="sd"> alibi_slopes: Tensor</span>
<span class="sd"> The ALiBi slopes. The ALiBi bias is computed on-the-fly in the kernel</span>
<span class="sd"> when possible,</span>
<span class="sd"> tp_size: int</span>
<span class="sd"> The number of processes/GPUs when tensor parallelism is activated,</span>
<span class="sd"> tp_rank: int</span>
<span class="sd"> The rank of that process (when running tensor parallelism),</span>
<span class="sd"> kv_cache_block_offsets:</span>
<span class="sd"> The tensor of block offsets for the KV cache. Its shape is</span>
<span class="sd"> [num_layers, max_batch_size, max_beam_width, 2, max_blocks_per_sequence * 2],</span>
<span class="sd"> See KV cache section in docs/source/advanced/gpt-attention.md, on gpu,</span>
<span class="sd"> host_kv_cache_block_offsets:</span>
<span class="sd"> The same as kv_cache_block_offsets, but on cpu,</span>
<span class="sd"> host_kv_cache_pool_pointers:</span>
<span class="sd"> The tensor of pool pointers for the KV cache. Its shape is [num_layers, 2],</span>
<span class="sd"> See KV cache section in docs/source/advanced/gpt-attention.md, on cpu,</span>
<span class="sd"> host_kv_cache_pool_mapping:</span>
<span class="sd"> The tensor of pool mapping for the different memory pools. Its shape is [num_layers,2] - for each layer, the index of the pool, and the index of the layer within the pool,</span>
<span class="sd"> do_cross_attention: bool = False</span>
<span class="sd"> Do we use this as cross attention instead of self attention,</span>
<span class="sd"> cross_kv: Tensor = None</span>
<span class="sd"> The KV tensor of encoder output hidden states. Its shape is [batch_size, max_seqlen, 2 * kvHeadNum * headSize] in padded mode and [1, num_tokens, 2 * kvHeadNum * headSize] in</span>
<span class="sd"> packed mode,</span>
<span class="sd"> cross_kv_length: Tensor = None</span>
<span class="sd"> The length of the longest encoder output sequence,</span>
<span class="sd"> encoder_input_lengths: Tensor</span>
<span class="sd"> The tensor that stores the length of each encoder input sequence. Its shape is [batch_size],</span>
<span class="sd"> logn_scaling: Tensor = None</span>
<span class="sd"> The logn scaling tensor [max_position_embedding_len], which is applied to q in order to help extrapolation</span>
<span class="sd"> relative_attention_bias: Tensor = None</span>
<span class="sd"> The relative attention bias [num_heads, max_seq_len, max_seq_len], or The relative attention embedding table for implicit mode, [num_heads, num_buckets].</span>
<span class="sd"> max_distance: int = 0</span>
<span class="sd"> The maximum distance of relative position in attention, for implicit mode.</span>
<span class="sd"> Default value is 0, meaning to use the regular mode of relative attention bias.</span>
<span class="sd"> Implicit mode is only enabled when passing in non-zero positive max_distance value.</span>
<span class="sd"> See relative attention bias in docs/source/advanced/gpt-attention.md</span>
<span class="sd"> host_context_lengths: Tensor = None (On CPU)</span>
<span class="sd"> A host tensor that contains the lengths of the different inputs,</span>
<span class="sd"> qkv_bias: Tensor = None,</span>
<span class="sd"> The qkv bias tensor.</span>
<span class="sd"> use_cache: bool = True</span>
<span class="sd"> Do we need to store kv cache ? not needed if there is no generation phase.</span>
<span class="sd"> spec_decoding_is_generation_length_variable: bool = False,</span>
<span class="sd"> Whether the generation lengths can be different for each sequence in a batch.</span>
<span class="sd"> For Medusa, this should be set False.</span>
<span class="sd"> For Redrafter, this should be set to True.</span>
<span class="sd"> spec_decoding_max_generation_length: int = 0,</span>
<span class="sd"> The maximum number of tokens possible in the generation phase per sequence.</span>
<span class="sd"> spec_decoding_generation_lengths: Tensor = None,</span>
<span class="sd"> The generation phase tokens&#39; lengths for each sequence.</span>
<span class="sd"> Shape: [batch_size]</span>
<span class="sd"> spec_decoding_position_offsets: Tensor = None,</span>
<span class="sd"> The speculative decoding tokens&#39;s position offsets (shared by all sequences).</span>
<span class="sd"> Shape: [batch_size, num_draft_tokens + 1].</span>
<span class="sd"> spec_decoding_packed_mask: Tensor = None,</span>
<span class="sd"> The speculative decoding tokens&#39;s attention mask (packed into uint32_t bits).</span>
<span class="sd"> remove_input_padding is False:</span>
<span class="sd"> Shape: [batch_size, num_draft_tokens + 1, divUp(num_draft_tokens + 1, 32)].</span>
<span class="sd"> remove_input_padding is True:</span>
<span class="sd"> Shape: [sum(spec_decoding_generation_lengths), divUp(num_draft_tokens + 1, 32)].</span>
<span class="sd"> long_rope_rotary_inv_freq: float Tensor</span>
<span class="sd"> Additional rotary inv freq used for longer sequence lengths. Shape: [head_size / 2]</span>
<span class="sd"> long_rope_rotary_cos_sin: float2(cos/sin) Tensor</span>
<span class="sd"> Additional rotary cos/sin cache used for longer sequence lengths.</span>
<span class="sd"> is_mla_enabled_flag: bool = False</span>
<span class="sd"> Do we need to enable deepseekv2 mla?</span>
<span class="sd"> host_runtime_perf_knobs: Tensor = None,</span>
<span class="sd"> The runtime perf knobs bit mask, controls whether to use certain perf knob in the runtime.</span>
<span class="sd"> host_context_progress: Tensor = None,</span>
<span class="sd"> The structure used to track layer-wise progress in context phase.</span>
<span class="sd"> skip_attn: Tensor = None,</span>
<span class="sd"> A bool tensor on CPU. If it is true, don&#39;t run attention plugin, returning directly.</span>
<span class="sd"> num_kv_heads_origin: int</span>
<span class="sd"> The origin number of KV heads, without the process of TP</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by that layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="n">host_request_types</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="p">(</span><span class="n">alibi_slopes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="o">==</span> <span class="p">(</span><span class="n">position_embedding_type</span><span class="o">.</span><span class="n">is_alibi</span><span class="p">())</span>
<span class="k">assert</span> <span class="p">(</span><span class="n">mrope_rotary_cos_sin</span>
<span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="o">==</span> <span class="p">(</span><span class="n">position_embedding_type</span><span class="o">.</span><span class="n">is_mrope</span><span class="p">())</span>
<span class="n">attn_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;GPTAttention&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">attn_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="n">host_context_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">(</span>
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">host_max_attention_window_sizes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="n">host_sink_token_length</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">paged_kv_cache_flag</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_kv_cache</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
<span class="n">is_unfuse_qkv_gemm</span> <span class="o">=</span> <span class="mi">1</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">is_unfuse_qkv_gemm</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">context_fmha_type</span>
<span class="k">if</span> <span class="n">do_cross_attention</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
<span class="k">pass</span>
<span class="k">if</span> <span class="n">logn_scaling</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">use_logn_scaling</span> <span class="o">=</span> <span class="mi">1</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">use_logn_scaling</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">if</span> <span class="n">num_kv_heads_origin</span> <span class="o">&lt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">num_kv_heads_origin</span> <span class="o">=</span> <span class="n">num_kv_heads</span>
<span class="n">unfuse_qkv_gemm</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;unfuse_qkv_gemm&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">is_unfuse_qkv_gemm</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">layer_idx</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;layer_idx&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">layer_idx</span><span class="p">,</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">nheads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;num_heads&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">vision_start</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;vision_start&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">vision_start</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">vision_length</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;vision_length&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">vision_length</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">num_kv_heads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;num_kv_heads&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_kv_heads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">num_kv_heads_origin</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;num_kv_heads_origin&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_kv_heads_origin</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">head_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;head_size&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">hidden_size_per_head</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">unidirectional</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;unidirectional&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">q_scaling</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;q_scaling&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">q_scaling</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">attn_logit_softcapping_scale</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;attn_logit_softcapping_scale&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">attn_logit_softcapping_scale</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">rotary_embedding_dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;rotary_embedding_dim&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">rotary_embedding_base</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;rotary_embedding_base&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_base</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">rotary_embedding_scale_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;rotary_embedding_scale_type&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_scale_type</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">rotary_embedding_scale</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;rotary_embedding_scale&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_scale</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">rotary_embedding_short_m_scale</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;rotary_embedding_short_m_scale&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_short_m_scale</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">rotary_embedding_long_m_scale</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;rotary_embedding_long_m_scale&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_long_m_scale</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">rotary_embedding_max_positions</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;rotary_embedding_max_positions&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_max_positions</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">rotary_embedding_original_max_positions</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;rotary_embedding_original_max_positions&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">rotary_embedding_original_max_positions</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">position_embedding_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;position_embedding_type&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">position_embedding_type</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">context_fmha_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;context_fmha_type&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">context_fmha_type</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;remove_input_padding&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">is_spec_decoding_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;is_spec_decoding_enabled&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">spec_decoding_packed_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">spec_decoding_is_generation_length_variable</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;spec_decoding_is_generation_length_variable&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">spec_decoding_is_generation_length_variable</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">spec_decoding_max_generation_length</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;spec_decoding_max_generation_length&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">spec_decoding_max_generation_length</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">is_mla_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;is_mla_enabled&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">is_mla_enabled_flag</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">q_lora_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;q_lora_rank&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">q_lora_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">kv_lora_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;kv_lora_rank&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">kv_lora_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">qk_nope_head_dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;qk_nope_head_dim&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">qk_nope_head_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">qk_rope_head_dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;qk_rope_head_dim&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">qk_rope_head_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">v_head_dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;v_head_dim&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">v_head_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gpt_attention_plugin</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="c1"># reset mask_type to custom_mask.</span>
<span class="k">if</span> <span class="p">(</span><span class="n">attention_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">attention_packed_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">):</span>
<span class="c1"># context fmha needs packed mask.</span>
<span class="k">assert</span> <span class="n">attention_packed_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">get_sm_version</span><span class="p">()</span> <span class="o">&lt;</span> <span class="mi">100</span><span class="p">:</span>
<span class="n">mask_type</span> <span class="o">=</span> <span class="n">AttentionMaskType</span><span class="o">.</span><span class="n">custom_mask</span>
<span class="n">mask_type_filed</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;mask_type&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">mask_type</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">block_sparse_block_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;block_sparse_block_size&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">block_sparse_block_size</span><span class="p">],</span>
<span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">block_sparse_homo_head_pattern</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;block_sparse_homo_head_pattern&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">block_sparse_homo_head_pattern</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">block_sparse_num_local_blocks</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;block_sparse_num_local_blocks&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">block_sparse_num_local_blocks</span><span class="p">],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">block_sparse_vertical_stride</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;block_sparse_vertical_stride&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">block_sparse_vertical_stride</span><span class="p">],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">tp_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;tp_size&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tp_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">tp_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;tp_rank&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">tp_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">kv_cache_quant_mode</span><span class="p">,</span> <span class="n">QuantModeWrapper</span><span class="p">):</span>
<span class="c1"># Now in TRT-LLM only use global kv_cache, so it&#39;s enough to get the first quant mode from list</span>
<span class="n">kv_cache_quant_mode</span> <span class="o">=</span> <span class="n">kv_cache_quant_mode</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">kv_cache_quant_mode_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;kv_cache_quant_mode&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">kv_cache_quant_mode</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">paged_kv_cache</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;paged_kv_cache&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">paged_kv_cache_flag</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">tokens_per_block</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;tokens_per_block&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">max_context_length</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;max_context_length&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_context_length</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pos_shift_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;pos_shift_enabled&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">streamingllm</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">dense_context_fmha</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;dense_context_fmha&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">streamingllm</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="k">if</span> <span class="n">qkv_bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">qkv_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;qkv_bias_enabled&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">qkv_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;qkv_bias_enabled&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">do_cross_attention_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;do_cross_attention&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">do_cross_attention</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">max_distance</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;max_distance&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_distance</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">use_paged_context_fmha_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;use_paged_context_fmha&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">use_fp8_context_fmha_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;use_fp8_context_fmha&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">has_full_attention_mask_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;has_full_attention_mask&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">attention_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">use_cache_pf</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;use_cache&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">use_cache</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">fuse_fp4_quant</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">fuse_fp4_quant</span>
<span class="n">fuse_fp4_quant_pf</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;fuse_fp4_quant&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">fuse_fp4_quant</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">skip_attn_pf</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;skip_attn&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">skip_attn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">cp_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;cp_size&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">cp_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">cp_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;cp_rank&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">cp_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">cp_group</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">cp_group</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
<span class="n">cp_group</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;cp_group&quot;</span><span class="p">,</span> <span class="n">cp_group</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">use_logn_scaling</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;use_logn_scaling&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">use_logn_scaling</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
<span class="n">layer_idx</span><span class="p">,</span> <span class="n">nheads</span><span class="p">,</span> <span class="n">vision_start</span><span class="p">,</span> <span class="n">vision_length</span><span class="p">,</span> <span class="n">num_kv_heads</span><span class="p">,</span>
<span class="n">num_kv_heads_origin</span><span class="p">,</span> <span class="n">head_size</span><span class="p">,</span> <span class="n">unidirectional</span><span class="p">,</span> <span class="n">q_scaling</span><span class="p">,</span>
<span class="n">attn_logit_softcapping_scale</span><span class="p">,</span> <span class="n">position_embedding_type</span><span class="p">,</span>
<span class="n">rotary_embedding_dim</span><span class="p">,</span> <span class="n">rotary_embedding_base</span><span class="p">,</span>
<span class="n">rotary_embedding_scale_type</span><span class="p">,</span> <span class="n">rotary_embedding_scale</span><span class="p">,</span>
<span class="n">rotary_embedding_short_m_scale</span><span class="p">,</span> <span class="n">rotary_embedding_long_m_scale</span><span class="p">,</span>
<span class="n">rotary_embedding_max_positions</span><span class="p">,</span> <span class="n">rotary_embedding_original_max_positions</span><span class="p">,</span>
<span class="n">tp_size</span><span class="p">,</span> <span class="n">tp_rank</span><span class="p">,</span> <span class="n">unfuse_qkv_gemm</span><span class="p">,</span> <span class="n">context_fmha_type</span><span class="p">,</span>
<span class="n">kv_cache_quant_mode_field</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span> <span class="n">mask_type_filed</span><span class="p">,</span>
<span class="n">block_sparse_block_size</span><span class="p">,</span> <span class="n">block_sparse_homo_head_pattern</span><span class="p">,</span>
<span class="n">block_sparse_num_local_blocks</span><span class="p">,</span> <span class="n">block_sparse_vertical_stride</span><span class="p">,</span>
<span class="n">paged_kv_cache</span><span class="p">,</span> <span class="n">tokens_per_block</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span> <span class="n">max_context_length</span><span class="p">,</span>
<span class="n">qkv_bias_enabled</span><span class="p">,</span> <span class="n">do_cross_attention_field</span><span class="p">,</span> <span class="n">max_distance</span><span class="p">,</span>
<span class="n">pos_shift_enabled</span><span class="p">,</span> <span class="n">dense_context_fmha</span><span class="p">,</span> <span class="n">use_paged_context_fmha_field</span><span class="p">,</span>
<span class="n">use_fp8_context_fmha_field</span><span class="p">,</span> <span class="n">has_full_attention_mask_field</span><span class="p">,</span> <span class="n">use_cache_pf</span><span class="p">,</span>
<span class="n">is_spec_decoding_enabled</span><span class="p">,</span> <span class="n">spec_decoding_is_generation_length_variable</span><span class="p">,</span>
<span class="n">spec_decoding_max_generation_length</span><span class="p">,</span> <span class="n">is_mla_enabled</span><span class="p">,</span> <span class="n">q_lora_rank</span><span class="p">,</span>
<span class="n">kv_lora_rank</span><span class="p">,</span> <span class="n">qk_nope_head_dim</span><span class="p">,</span> <span class="n">qk_rope_head_dim</span><span class="p">,</span> <span class="n">v_head_dim</span><span class="p">,</span>
<span class="n">fuse_fp4_quant_pf</span><span class="p">,</span> <span class="n">skip_attn_pf</span><span class="p">,</span> <span class="n">cp_size</span><span class="p">,</span> <span class="n">cp_rank</span><span class="p">,</span> <span class="n">cp_group</span><span class="p">,</span>
<span class="n">use_logn_scaling</span>
<span class="p">])</span>
<span class="n">attn_plug</span> <span class="o">=</span> <span class="n">attn_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;causal_attn&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">attn_plug</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="o">*</span><span class="n">qkv</span><span class="p">]</span> <span class="k">if</span> <span class="n">is_unfuse_qkv_gemm</span> <span class="k">else</span> <span class="p">[</span><span class="n">qkv</span><span class="p">]</span>
<span class="k">if</span> <span class="n">attention_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">mask_type</span> <span class="o">==</span> <span class="n">AttentionMaskType</span><span class="o">.</span><span class="n">custom_mask</span><span class="p">:</span>
<span class="c1"># useFullCustomMask</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">attention_mask</span><span class="p">]</span>
<span class="k">if</span> <span class="n">attention_packed_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">get_sm_version</span><span class="p">()</span> <span class="o">&lt;</span> <span class="mi">100</span><span class="p">:</span>
<span class="c1"># usePackedCustomMask</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">attention_packed_mask</span><span class="p">]</span>
<span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span>
<span class="n">sequence_length</span><span class="p">,</span>
<span class="n">host_past_key_value_lengths</span><span class="p">,</span>
<span class="n">host_max_attention_window_sizes</span><span class="p">,</span>
<span class="n">host_sink_token_length</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span>
<span class="n">cache_indirection</span><span class="p">,</span>
<span class="n">host_request_types</span><span class="p">,</span>
<span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span>
<span class="n">host_max_attention_window_sizes</span><span class="p">,</span>
<span class="n">host_sink_token_length</span><span class="p">,</span>
<span class="n">context_lengths</span><span class="p">,</span>
<span class="n">host_request_types</span><span class="p">,</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">use_cache</span><span class="p">:</span>
<span class="k">if</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">kv_cache_block_offsets</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;Paged kv cache is enabled, the kv_cache_block_offsets tensor shall not be None&quot;</span>
<span class="k">assert</span> <span class="n">host_kv_cache_block_offsets</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;Paged kv cache is enabled, the host_kv_cache_block_offsets tensor shall not be None&quot;</span>
<span class="k">assert</span> <span class="n">host_kv_cache_pool_pointers</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;Paged kv cache is enabled, the host_kv_cache_pool_pointers tensor shall not be None&quot;</span>
<span class="k">assert</span> <span class="n">host_kv_cache_pool_mapping</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;Paged kv cache is enabled, the host_kv_cache_pool_mapping tensor shall not be None&quot;</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span>
<span class="n">kv_cache_block_offsets</span><span class="p">,</span> <span class="n">host_kv_cache_block_offsets</span><span class="p">,</span>
<span class="n">host_kv_cache_pool_pointers</span><span class="p">,</span> <span class="n">host_kv_cache_pool_mapping</span>
<span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">past_key_value</span><span class="p">]</span>
<span class="k">if</span> <span class="n">use_cache</span> <span class="ow">and</span> <span class="n">kv_cache_quant_mode</span><span class="o">.</span><span class="n">has_kv_cache_quant</span><span class="p">():</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">kv_orig_quant_scale</span><span class="p">,</span> <span class="n">kv_quant_orig_scale</span><span class="p">]</span>
<span class="k">if</span> <span class="n">attention_output_orig_quant_scale</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">default_net</span><span class="p">(</span>
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span><span class="p">,</span> <span class="s2">&quot;FP8 Context FMHA needs to be enabled&quot;</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">attention_output_orig_quant_scale</span><span class="p">]</span>
<span class="k">if</span> <span class="n">fuse_fp4_quant</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">attention_output_sf_scale</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;attention_output_sf_scale must be provided when fuse_fp4_quant is enabled.&quot;</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">attention_output_sf_scale</span><span class="p">]</span>
<span class="k">if</span> <span class="n">rotary_inv_freq</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">rotary_inv_freq</span><span class="p">]</span>
<span class="k">if</span> <span class="n">rotary_cos_sin</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">rotary_cos_sin</span><span class="p">]</span>
<span class="k">if</span> <span class="n">alibi_slopes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">alibi_slopes</span><span class="p">]</span>
<span class="k">if</span> <span class="n">relative_attention_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">relative_attention_bias</span><span class="p">]</span>
<span class="k">if</span> <span class="n">do_cross_attention</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">cross_kv</span><span class="p">,</span> <span class="n">cross_kv_length</span><span class="p">,</span> <span class="n">encoder_input_lengths</span><span class="p">]</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_lengths</span><span class="p">]</span>
<span class="k">if</span> <span class="n">qkv_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">qkv_bias</span><span class="p">]</span>
<span class="k">if</span> <span class="n">spec_decoding_packed_mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># add position_ids as well only if speculative decoding mode</span>
<span class="k">assert</span> <span class="n">spec_decoding_position_offsets</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="n">spec_decoding_generation_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="n">spec_decoding_use</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span>
<span class="n">spec_decoding_generation_lengths</span><span class="p">,</span> <span class="n">spec_decoding_packed_mask</span><span class="p">,</span>
<span class="n">spec_decoding_position_offsets</span><span class="p">,</span> <span class="n">spec_decoding_use</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">long_rope_rotary_inv_freq</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">long_rope_rotary_cos_sin</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">long_rope_rotary_inv_freq</span><span class="p">,</span> <span class="n">long_rope_rotary_cos_sin</span><span class="p">]</span>
<span class="k">if</span> <span class="n">mrope_rotary_cos_sin</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">mrope_position_deltas</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span>
<span class="n">mrope_rotary_cos_sin</span><span class="p">,</span>
<span class="n">mrope_position_deltas</span><span class="p">,</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">host_runtime_perf_knobs</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_runtime_perf_knobs</span><span class="p">]</span>
<span class="k">if</span> <span class="n">host_context_progress</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_progress</span><span class="p">]</span>
<span class="k">if</span> <span class="n">is_mla_enabled_flag</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">q_b_proj</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="n">kv_b_proj</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="n">k_b_proj_trans</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">q_b_proj</span><span class="p">,</span> <span class="n">kv_b_proj</span><span class="p">,</span> <span class="n">k_b_proj_trans</span><span class="p">]</span>
<span class="k">if</span> <span class="n">skip_attn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">skip_attn</span><span class="p">]</span>
<span class="k">if</span> <span class="n">logn_scaling</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">logn_scaling</span><span class="p">]</span>
<span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">i</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;Found None input for </span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s2"> th item in plugin inputs </span><span class="si">{</span><span class="n">plug_inputs</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">attn_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">attn_plg_creator</span><span class="p">,</span> <span class="s2">&quot;causal_attn&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="n">expected_outputs</span> <span class="o">=</span> <span class="mi">1</span>
<span class="c1"># The output scaling factor tensor.</span>
<span class="n">output_sf</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">fuse_fp4_quant</span><span class="p">:</span>
<span class="n">output_sf</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="n">expected_outputs</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="n">expected_outputs</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="n">present_key_value</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">use_cache</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
<span class="n">present_key_value</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="n">expected_outputs</span><span class="p">),</span>
<span class="n">layer</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">present_key_value</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">expected_outputs</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">assert</span> <span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span> <span class="o">==</span> <span class="n">expected_outputs</span><span class="p">,</span> \
<span class="sa">f</span><span class="s2">&quot;Plugin outputs number mismatch with expected, got </span><span class="si">{</span><span class="n">layer</span><span class="o">.</span><span class="n">num_outputs</span><span class="si">}</span><span class="s2">, expected </span><span class="si">{</span><span class="n">expected_outputs</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="k">if</span> <span class="n">kv_cache_quant_mode</span><span class="o">.</span><span class="n">has_int8_kv_cache</span><span class="p">(</span>
<span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">paged_kv_cache_flag</span><span class="p">:</span>
<span class="c1"># past key value</span>
<span class="n">layer</span><span class="o">.</span><span class="n">get_input</span><span class="p">(</span><span class="mi">8</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
<span class="c1"># present key value</span>
<span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="n">expected_outputs</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">layer</span><span class="o">.</span><span class="n">get_input</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">get_input</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="n">expected_outputs</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">set_dynamic_range</span><span class="p">(</span><span class="o">-</span><span class="mi">127</span><span class="p">,</span> <span class="mi">127</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">output</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">fuse_fp4_quant</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">output_sf</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">return</span> <span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">output_sf</span><span class="p">),</span> <span class="n">present_key_value</span>
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">present_key_value</span></div>
<div class="viewcode-block" id="assertion">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.assertion">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">assertion</span><span class="p">(</span><span class="n">condition</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">message</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;&#39;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_assertion</span><span class="p">(</span><span class="n">condition</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">message</span><span class="p">)</span></div>
<div class="viewcode-block" id="layer_norm">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.layer_norm">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">layer_norm</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">normalized_shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-05</span><span class="p">,</span>
<span class="n">use_diff_of_squares</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a layer-norm operation on a tensor.</span>
<span class="sd"> That operation applies the layer-normalization to its input tensor. In its</span>
<span class="sd"> simplest form, for large language models, the &#39;normalized_shape&#39; should be</span>
<span class="sd"> set to the hidden dimension of the activation tensor. Otherwise, it is the</span>
<span class="sd"> shape of the normalized fraction of the tensor (starting from the</span>
<span class="sd"> right-most dimension).</span>
<span class="sd"> The &#39;weight&#39; tensor corresponds to &#39;gamma&#39; in the layer-norm formula and</span>
<span class="sd"> &#39;bias&#39; is &#39;beta&#39;. The &#39;eps&#39; value is added to the variance before computing</span>
<span class="sd"> the squared-root.</span>
<span class="sd"> This implementation (when using the plugin) supports an additional flag to</span>
<span class="sd"> enable/disable the use of a difference of squares (&#39;Var = Mean(X^2) -</span>
<span class="sd"> Mean(X)^2&#39;).</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The tensor to normalize.</span>
<span class="sd"> normalized_shape : Union[int, Tuple[int]]</span>
<span class="sd"> The shape of the sub-tensor that is normalized. Use &#39;hidden_dim&#39; to</span>
<span class="sd"> normalize the inner-most dimension of an activation tensor in LLMs.</span>
<span class="sd"> weight : Optional[Tensor] = None</span>
<span class="sd"> The &#39;gamma&#39; term in layer-norm. Its shape must be</span>
<span class="sd"> &#39;normalized_shape&#39;.</span>
<span class="sd"> bias : Optional[Tensor] = None</span>
<span class="sd"> The &#39;beta&#39; term in layer-norm. Its shape must be</span>
<span class="sd"> &#39;normalized_shape&#39;.</span>
<span class="sd"> eps : float</span>
<span class="sd"> The epsilon term to be added to the variance in the squared-root.</span>
<span class="sd"> use_diff_of_squares : bool</span>
<span class="sd"> Does the plugin use the difference of squares to compute the</span>
<span class="sd"> variance?</span>
<span class="sd"> Returns:</span>
<span class="sd"> The output tensor of that operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="nb">input</span><span class="p">,</span> <span class="n">weight</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">weight</span><span class="p">)</span>
<span class="nb">input</span><span class="p">,</span> <span class="n">bias</span> <span class="o">=</span> <span class="n">broadcast_helper</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">bias</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span> <span class="c1"># FIXME: better way?</span>
<span class="n">axis</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">axis</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">)</span>
<span class="n">axes_mask</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()):</span>
<span class="n">axes_mask</span> <span class="o">|=</span> <span class="mi">1</span> <span class="o">&lt;&lt;</span> <span class="n">i</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_normalization</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">bias</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">axes_mask</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">epsilon</span> <span class="o">=</span> <span class="n">eps</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="rms_norm">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.rms_norm">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">rms_norm</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">normalized_shape</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
<span class="n">num_groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">weight</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-06</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a RMS norm operation on a tensor.</span>
<span class="sd"> That operation applies the rms-normalization to its input tensor. In its</span>
<span class="sd"> simplest form, for large language models, the &#39;normalized_shape&#39; should be</span>
<span class="sd"> set to the hidden dimension of the activation tensor. Otherwise, it is the</span>
<span class="sd"> shape of the normalized fraction of the tensor (starting from the</span>
<span class="sd"> right-most dimension).</span>
<span class="sd"> The &#39;weight&#39; tensor corresponds to &#39;gamma&#39; in the rms-norm formula.</span>
<span class="sd"> The &#39;eps&#39; value is added to the variance before computing the squared-root.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input: Tensor</span>
<span class="sd"> The tensor to normalize.</span>
<span class="sd"> normalized_shape : Union[int, Tuple[int]]</span>
<span class="sd"> The shape of the sub-tensor that is normalized. Use &#39;hidden_dim&#39; to</span>
<span class="sd"> normalize the inner-most dimension of an activation tensor in LLMs.</span>
<span class="sd"> num_groups: int = 1</span>
<span class="sd"> The group size.</span>
<span class="sd"> weight : Optional[Tensor] = None</span>
<span class="sd"> The &#39;gamma&#39; term in layer-norm. Its shape must be</span>
<span class="sd"> &#39;normalized_shape&#39;.</span>
<span class="sd"> eps : float</span>
<span class="sd"> The epsilon term to be added to the variance in the squared-root.weig</span>
<span class="sd"> Returns:</span>
<span class="sd"> The output tensor of that operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">normalized_shape</span> <span class="o">=</span> <span class="p">[</span><span class="n">normalized_shape</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
<span class="n">normalized_shape</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> <span class="n">normalized_shape</span>
<span class="n">dim</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="o">-</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">))])</span>
<span class="k">if</span> <span class="n">num_groups</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">normalized_shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span>
<span class="n">num_channels</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="n">old_shape</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span>
<span class="n">new_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">+</span>
<span class="p">[</span><span class="n">num_groups</span><span class="p">,</span> <span class="n">num_channels</span> <span class="o">//</span> <span class="n">num_groups</span><span class="p">])</span>
<span class="nb">input</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">new_shape</span><span class="p">)</span>
<span class="k">with</span> <span class="n">precision</span><span class="p">(</span><span class="s2">&quot;float32&quot;</span><span class="p">):</span>
<span class="n">input_dtype</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">dtype</span>
<span class="n">fp32_input</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="s2">&quot;float32&quot;</span><span class="p">)</span>
<span class="n">varx</span> <span class="o">=</span> <span class="nb">pow</span><span class="p">(</span><span class="n">fp32_input</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">)</span>
<span class="n">varx</span> <span class="o">=</span> <span class="n">varx</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">denom</span> <span class="o">=</span> <span class="n">varx</span> <span class="o">+</span> <span class="n">eps</span>
<span class="n">denom</span> <span class="o">=</span> <span class="n">denom</span><span class="o">.</span><span class="n">sqrt</span><span class="p">()</span>
<span class="n">fp32_y</span> <span class="o">=</span> <span class="n">fp32_input</span> <span class="o">/</span> <span class="n">denom</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">fp32_y</span><span class="p">,</span> <span class="n">input_dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">num_groups</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">old_shape</span><span class="p">)</span>
<span class="k">if</span> <span class="n">weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">*</span> <span class="n">weight</span>
<span class="k">return</span> <span class="n">y</span></div>
<div class="viewcode-block" id="rearrange">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.rearrange">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">rearrange</span><span class="p">(</span><span class="n">inputs</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]],</span> <span class="n">expression</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a rearrange operation on a tensor.</span>
<span class="sd"> This operation is a reader-friendly smart element reordering for multidimensional tensors,</span>
<span class="sd"> including functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,</span>
<span class="sd"> stack, concatenate and other operations. Please see: https://einops.rocks/api/rearrange/</span>
<span class="sd"> For example, if the shape of input tensor is [32, 30, 40, 3], and run:</span>
<span class="sd"> `rearrange(x, &#39;b (h h1) (w w1) c -&gt; b h w 1 (c h1 w1) 1&#39;, h1=2, w1=2)`</span>
<span class="sd"> it would produce a tensor with shape as [32, 15, 20, 1, 12, 1].</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input: Union[Tensor, Sequence[Tensor]]</span>
<span class="sd"> If it is a tensor, it will directly operate on it.</span>
<span class="sd"> Otherwise, if it is a sequence, it will concat it to a tensor and then</span>
<span class="sd"> operates on it.</span>
<span class="sd"> expression : str</span>
<span class="sd"> The expression about how to reorder the tensor in a reader-friendly way.</span>
<span class="sd"> kwargs:</span>
<span class="sd"> Keyword arguments to set some identifiers with specific values.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The output tensor of this operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">re</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_init_expression</span><span class="p">(</span><span class="n">expr</span><span class="p">):</span>
<span class="n">expr_items</span> <span class="o">=</span> <span class="n">expr</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot; &quot;</span><span class="p">)</span>
<span class="n">tmp_name_index</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">item</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">expr_items</span><span class="p">):</span>
<span class="n">values</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">findall</span><span class="p">(</span><span class="sa">r</span><span class="s1">&#39;\b\d+\b&#39;</span><span class="p">,</span> <span class="n">item</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">values</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">prefix</span> <span class="o">=</span> <span class="s2">&quot;(&quot;</span> <span class="k">if</span> <span class="s2">&quot;(&quot;</span> <span class="ow">in</span> <span class="n">item</span> <span class="k">else</span> <span class="s2">&quot;&quot;</span>
<span class="n">subfix</span> <span class="o">=</span> <span class="s2">&quot;)&quot;</span> <span class="k">if</span> <span class="s2">&quot;)&quot;</span> <span class="ow">in</span> <span class="n">item</span> <span class="k">else</span> <span class="s2">&quot;&quot;</span>
<span class="n">expr_items</span><span class="p">[</span>
<span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">prefix</span><span class="si">}</span><span class="s2">NumericId</span><span class="si">{</span><span class="n">tmp_name_index</span><span class="si">}</span><span class="s2">Val</span><span class="si">{</span><span class="n">values</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}{</span><span class="n">subfix</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">tmp_name_index</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="k">return</span> <span class="s2">&quot; &quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">expr_items</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_get_all_identifier</span><span class="p">(</span><span class="n">expr</span><span class="p">):</span>
<span class="k">return</span> <span class="n">re</span><span class="o">.</span><span class="n">findall</span><span class="p">(</span><span class="sa">r</span><span class="s1">&#39;\b[a-zA-Z_]+\d*\b&#39;</span><span class="p">,</span> <span class="n">expr</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_get_all_symbols</span><span class="p">(</span><span class="n">expr</span><span class="p">):</span>
<span class="k">return</span> <span class="n">re</span><span class="o">.</span><span class="n">findall</span><span class="p">(</span><span class="sa">r</span><span class="s1">&#39;\b\w+\b&#39;</span><span class="p">,</span> <span class="n">expr</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_get_dim_expr</span><span class="p">(</span><span class="n">expr</span><span class="p">):</span>
<span class="k">return</span> <span class="p">[</span>
<span class="n">_get_all_symbols</span><span class="p">(</span><span class="n">match</span><span class="o">.</span><span class="n">group</span><span class="p">())</span>
<span class="k">for</span> <span class="n">match</span> <span class="ow">in</span> <span class="n">re</span><span class="o">.</span><span class="n">finditer</span><span class="p">(</span><span class="sa">r</span><span class="s1">&#39;\b\w+\b|\(.*?\)&#39;</span><span class="p">,</span> <span class="n">expr</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">src_shape_expr</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">dst_shape_expr</span> <span class="o">=</span> <span class="n">expression</span><span class="o">.</span><span class="n">partition</span><span class="p">(</span><span class="s2">&quot;-&gt;&quot;</span><span class="p">)</span>
<span class="n">unknown_identifiers</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">findall</span><span class="p">(</span><span class="sa">r</span><span class="s1">&#39;[^a-zA-Z0-9_\(\)]&#39;</span><span class="p">,</span>
<span class="n">src_shape_expr</span> <span class="o">+</span> <span class="n">dst_shape_expr</span><span class="p">)</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span>
<span class="n">unknown_identifiers</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;Unknown identifiers: </span><span class="si">{</span><span class="n">unknown_identifiers</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">src_identifiers</span> <span class="o">=</span> <span class="n">_get_all_identifier</span><span class="p">(</span><span class="n">src_shape_expr</span><span class="p">)</span>
<span class="n">dst_identifiers</span> <span class="o">=</span> <span class="n">_get_all_identifier</span><span class="p">(</span><span class="n">dst_shape_expr</span><span class="p">)</span>
<span class="k">assert</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">src_identifiers</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">src_identifiers</span><span class="p">))</span>
<span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">dst_identifiers</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">dst_identifiers</span><span class="p">))</span>
<span class="p">),</span> <span class="s2">&quot;Indexing expression contains duplicate dimension.&quot;</span>
<span class="k">assert</span> <span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">src_identifiers</span><span class="p">)</span> <span class="o">==</span> <span class="nb">set</span><span class="p">(</span><span class="n">dst_identifiers</span><span class="p">)</span>
<span class="p">),</span> <span class="s2">&quot;Identifiers only on one side of expression (should be on both).&quot;</span>
<span class="n">new_expression</span> <span class="o">=</span> <span class="n">_init_expression</span><span class="p">(</span><span class="n">expression</span><span class="p">)</span>
<span class="n">src_shape_expr</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">dst_shape_expr</span> <span class="o">=</span> <span class="n">new_expression</span><span class="o">.</span><span class="n">partition</span><span class="p">(</span><span class="s2">&quot;-&gt;&quot;</span><span class="p">)</span>
<span class="c1"># concat if inputs are sequence of tensors</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">):</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="k">assert</span> <span class="p">(</span>
<span class="n">inputs</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">_get_dim_expr</span><span class="p">(</span><span class="n">src_shape_expr</span><span class="p">))</span>
<span class="p">),</span> <span class="sa">f</span><span class="s2">&quot;inputs.ndim() is </span><span class="si">{</span><span class="n">inputs</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span><span class="si">}</span><span class="s2"> while indexing expression has </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">_get_dim_expr</span><span class="p">(</span><span class="n">src_shape_expr</span><span class="p">))</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">src_symbols</span> <span class="o">=</span> <span class="n">_get_all_symbols</span><span class="p">(</span><span class="n">src_shape_expr</span><span class="p">)</span>
<span class="n">dst_symbols</span> <span class="o">=</span> <span class="n">_get_all_symbols</span><span class="p">(</span><span class="n">dst_shape_expr</span><span class="p">)</span>
<span class="c1"># find all the symbols-values mapping and store them in symbol_map</span>
<span class="n">symbol_map</span> <span class="o">=</span> <span class="p">{</span>
<span class="n">symbol</span><span class="p">:</span> <span class="p">{</span>
<span class="s2">&quot;updated&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
<span class="s2">&quot;value&quot;</span><span class="p">:</span> <span class="kc">None</span>
<span class="p">}</span>
<span class="k">for</span> <span class="n">symbol</span> <span class="ow">in</span> <span class="nb">set</span><span class="p">(</span><span class="n">src_symbols</span> <span class="o">+</span> <span class="n">dst_symbols</span><span class="p">)</span>
<span class="p">}</span>
<span class="k">for</span> <span class="n">symbol</span> <span class="ow">in</span> <span class="n">symbol_map</span><span class="p">:</span>
<span class="k">if</span> <span class="s2">&quot;NumericId&quot;</span> <span class="ow">in</span> <span class="n">symbol</span><span class="p">:</span>
<span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">&quot;value&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">symbol</span><span class="o">.</span><span class="n">partition</span><span class="p">(</span><span class="s2">&quot;Val&quot;</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
<span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">&quot;updated&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">for</span> <span class="n">symbol</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">&quot;value&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span>
<span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">&quot;updated&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">dim_expr</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">_get_dim_expr</span><span class="p">(</span><span class="n">src_shape_expr</span><span class="p">)):</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">dim_expr</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">symbol</span> <span class="o">=</span> <span class="n">dim_expr</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">&quot;updated&quot;</span><span class="p">]:</span>
<span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">&quot;value&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">idx</span><span class="p">)</span>
<span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">&quot;updated&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">divisors</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">unknown_symbol</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">for</span> <span class="n">symbol</span> <span class="ow">in</span> <span class="n">dim_expr</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">&quot;updated&quot;</span><span class="p">]:</span>
<span class="n">unknown_symbol</span> <span class="o">=</span> <span class="n">symbol</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">divisors</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">&quot;value&quot;</span><span class="p">])</span>
<span class="k">if</span> <span class="n">unknown_symbol</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">divisors</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span>
<span class="n">divisor</span> <span class="o">=</span> <span class="n">prod</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">divisors</span><span class="p">),</span> <span class="s2">&quot;int64&quot;</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">symbol_map</span><span class="p">[</span><span class="n">unknown_symbol</span><span class="p">][</span><span class="s2">&quot;value&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span>
<span class="n">idx</span><span class="p">)</span> <span class="o">/</span> <span class="n">divisor</span>
<span class="n">symbol_map</span><span class="p">[</span><span class="n">unknown_symbol</span><span class="p">][</span><span class="s2">&quot;updated&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">for</span> <span class="n">symbol</span><span class="p">,</span> <span class="n">item</span> <span class="ow">in</span> <span class="n">symbol_map</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">assert</span> <span class="p">(</span><span class="n">item</span><span class="p">[</span><span class="s2">&quot;updated&quot;</span><span class="p">]</span>
<span class="p">),</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">symbol</span><span class="si">}</span><span class="s2"> cannot be inferred, please set it manually&quot;</span>
<span class="n">dst_dims</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">dim_expr</span> <span class="ow">in</span> <span class="n">_get_dim_expr</span><span class="p">(</span><span class="n">dst_shape_expr</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">dim_expr</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">dst_dims</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">symbol_map</span><span class="p">[</span><span class="n">dim_expr</span><span class="p">[</span><span class="mi">0</span><span class="p">]][</span><span class="s2">&quot;value&quot;</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">accumulator</span> <span class="o">=</span> <span class="n">prod</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span><span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">&quot;value&quot;</span><span class="p">]</span> <span class="k">for</span> <span class="n">symbol</span> <span class="ow">in</span> <span class="n">dim_expr</span><span class="p">]),</span>
<span class="s2">&quot;int64&quot;</span><span class="p">),</span>
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">dst_dims</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">accumulator</span><span class="p">)</span>
<span class="n">dst_dims</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">dst_dims</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">),</span> <span class="s2">&quot;int64&quot;</span><span class="p">)</span>
<span class="n">src_indices</span> <span class="o">=</span> <span class="p">{</span><span class="n">symbol</span><span class="p">:</span> <span class="n">idx</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">symbol</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">src_identifiers</span><span class="p">)}</span>
<span class="n">permute_dims</span> <span class="o">=</span> <span class="p">[</span><span class="n">src_indices</span><span class="p">[</span><span class="n">symbol</span><span class="p">]</span> <span class="k">for</span> <span class="n">symbol</span> <span class="ow">in</span> <span class="n">dst_identifiers</span><span class="p">]</span>
<span class="n">symbol_shape</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span><span class="n">symbol_map</span><span class="p">[</span><span class="n">symbol</span><span class="p">][</span><span class="s2">&quot;value&quot;</span><span class="p">]</span> <span class="k">for</span> <span class="n">symbol</span> <span class="ow">in</span> <span class="n">src_identifiers</span><span class="p">],</span>
<span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">),</span> <span class="s2">&quot;int64&quot;</span><span class="p">)</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">symbol_shape</span><span class="p">)</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">permute</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">permute_dims</span><span class="p">)</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">dst_dims</span><span class="p">)</span>
<span class="k">return</span> <span class="n">tensor</span></div>
<div class="viewcode-block" id="repeat">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.repeat">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">repeat</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">sizes</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Repeats the tensor along the specified dimensions.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The tensor to be repeated.</span>
<span class="sd"> sizes : Sequence[int]</span>
<span class="sd"> The number of times to repeat the tensor along each dimension.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A tensor except for repeated input tensors along specified dim.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">&lt;=</span> <span class="nb">len</span><span class="p">(</span><span class="n">sizes</span><span class="p">),</span> \
<span class="s2">&quot;Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor&quot;</span>
<span class="n">repeated_tensor</span> <span class="o">=</span> <span class="nb">input</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="nb">len</span><span class="p">(</span><span class="n">sizes</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">):</span>
<span class="n">repeated_tensor</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">repeated_tensor</span><span class="p">]</span> <span class="o">*</span> <span class="n">sizes</span><span class="p">[</span><span class="n">k</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="n">k</span><span class="p">)</span>
<span class="k">return</span> <span class="n">repeated_tensor</span></div>
<div class="viewcode-block" id="repeat_interleave">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.repeat_interleave">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">repeat_interleave</span><span class="p">(</span><span class="n">tensor</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">repeats</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Repeats elements of a tensor along an axis.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> repeats : int</span>
<span class="sd"> The number of repetitions along axis specified.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The dimension along which repetitions are performed.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A tensor with the same shape as input except for repeated elements along specified dim.</span>
<span class="sd"> TODO: Allow repeats to be a list of integers and dim to be unspecified.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">expanded_tensor</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">dim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">tile_output_size</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
<span class="n">repeats</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="p">(</span><span class="n">dim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="k">else</span> <span class="n">shape</span><span class="p">(</span><span class="n">expanded_tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">expanded_tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">())</span>
<span class="p">])</span>
<span class="n">tile</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">expanded_tensor</span><span class="p">,</span> <span class="n">tile_output_size</span><span class="p">)</span>
<span class="n">tile_reshape_size</span> <span class="o">=</span> <span class="p">[</span><span class="n">shape</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">tensor</span><span class="o">.</span><span class="n">ndim</span><span class="p">())]</span>
<span class="n">tile_reshape_size</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">tile_reshape_size</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">*</span> <span class="n">repeats</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">tile</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">(</span><span class="n">tile_reshape_size</span><span class="p">))</span>
<span class="k">return</span> <span class="n">tensor</span></div>
<div class="viewcode-block" id="meshgrid2d">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.meshgrid2d">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">meshgrid2d</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Creates grids (2D) of coordinates specified by the 1D inputs (only supports `indexing=\&#39;xy\&#39;`).</span>
<span class="sd"> Parameters:</span>
<span class="sd"> x : Tensor</span>
<span class="sd"> The first input (1D) tensor.</span>
<span class="sd"> y : Tensor</span>
<span class="sd"> The second input (1D) tensor.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tuple of two tensors produced.</span>
<span class="sd"> TODO: Add full support for torch.meshgrid.</span>
<span class="sd"> See https://pytorch.org/docs/stable/generated/torch.meshgrid.html#torch-meshgrid</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="k">if</span> <span class="n">y</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">expand_dims</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">grid_x</span> <span class="o">=</span> <span class="n">repeat_interleave</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">shape</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">([</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">y</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]])</span>
<span class="n">grid_y</span> <span class="o">=</span> <span class="n">repeat</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span>
<span class="k">return</span> <span class="p">(</span><span class="n">grid_x</span><span class="p">,</span> <span class="n">grid_y</span><span class="p">)</span></div>
<div class="viewcode-block" id="generate_logn_scaling">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.generate_logn_scaling">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">generate_logn_scaling</span><span class="p">(</span><span class="n">seq_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8192</span><span class="p">,</span>
<span class="n">max_position_embeddings</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32768</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Compute the Log-N scaling vector for Qwen inference extrapolation</span>
<span class="sd"> Parameters:</span>
<span class="sd"> seq_length : int</span>
<span class="sd"> The max seq length in training (default to 8192 in Qwen-1)</span>
<span class="sd"> max_position_embeddings : int</span>
<span class="sd"> The max position embeddings. (default to 32768 in Qwen-1)</span>
<span class="sd"> Returns:</span>
<span class="sd"> A constant np.ndarray that contains logn scaling vector</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">logn_list</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">seq_length</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="o">&gt;</span> <span class="n">seq_length</span> <span class="k">else</span> <span class="mi">1</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">max_position_embeddings</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="p">]</span>
<span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">logn_list</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></div>
<div class="viewcode-block" id="generate_alibi_slopes">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.generate_alibi_slopes">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">generate_alibi_slopes</span><span class="p">(</span><span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">tp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">tp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">alibi_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">alibi_bias_max</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Compute the ALiBi slopes as described in https://arxiv.org/abs/2211.05100.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> num_heads : int</span>
<span class="sd"> The number of heads.</span>
<span class="sd"> dtype : trt.DataType</span>
<span class="sd"> The data type of the returned slopes</span>
<span class="sd"> tp_size : int</span>
<span class="sd"> The tensor parallelism size</span>
<span class="sd"> tp_rank : int</span>
<span class="sd"> The tensor parallelism rank</span>
<span class="sd"> Returns:</span>
<span class="sd"> A constant tensor that contains the ALiBi slopes.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">start_head_id</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">end_head_id</span> <span class="o">=</span> <span class="n">num_heads</span>
<span class="k">if</span> <span class="n">tp_size</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">rank_heads</span> <span class="o">=</span> <span class="n">num_heads</span> <span class="o">//</span> <span class="n">tp_size</span>
<span class="n">start_head_id</span> <span class="o">=</span> <span class="n">rank_heads</span> <span class="o">*</span> <span class="n">tp_rank</span>
<span class="n">end_head_id</span> <span class="o">=</span> <span class="n">start_head_id</span> <span class="o">+</span> <span class="n">rank_heads</span>
<span class="n">closest_power_of_2</span> <span class="o">=</span> <span class="mi">2</span><span class="o">**</span><span class="n">np</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">num_heads</span><span class="p">))</span>
<span class="c1"># FT&#39;s implementation</span>
<span class="c1"># https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/gen_relative_pos_bias.cu#L248</span>
<span class="n">slopes_ft</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">h_id</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">start_head_id</span><span class="p">,</span> <span class="n">end_head_id</span><span class="p">):</span>
<span class="k">if</span> <span class="n">h_id</span> <span class="o">&lt;</span> <span class="n">closest_power_of_2</span><span class="p">:</span>
<span class="n">slopes_ft</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">np</span><span class="o">.</span><span class="n">power</span><span class="p">(</span>
<span class="mi">2</span><span class="o">**</span><span class="p">(</span><span class="o">-</span><span class="p">(</span><span class="mi">2</span><span class="o">**-</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">closest_power_of_2</span><span class="p">)</span> <span class="o">-</span>
<span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">alibi_bias_max</span><span class="p">)))),</span> <span class="n">h_id</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">slopes_ft</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="n">np</span><span class="o">.</span><span class="n">power</span><span class="p">(</span>
<span class="mi">2</span><span class="o">**</span><span class="p">(</span><span class="o">-</span><span class="p">(</span><span class="mi">2</span><span class="o">**-</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">closest_power_of_2</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span> <span class="o">-</span>
<span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">alibi_bias_max</span><span class="p">)))),</span>
<span class="p">(</span><span class="n">h_id</span> <span class="o">-</span> <span class="n">closest_power_of_2</span><span class="p">)</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">slopes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">slopes_ft</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="n">slopes</span> <span class="o">=</span> <span class="n">alibi_scale</span> <span class="o">*</span> <span class="n">slopes</span>
<span class="n">slopes</span> <span class="o">=</span> <span class="n">slopes</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="n">end_head_id</span> <span class="o">-</span> <span class="n">start_head_id</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">return</span> <span class="n">slopes</span></div>
<div class="viewcode-block" id="generate_alibi_biases">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.generate_alibi_biases">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">generate_alibi_biases</span><span class="p">(</span><span class="n">slopes</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">key_length</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Compute the ALiBi biases as described in https://arxiv.org/abs/2211.05100.</span>
<span class="sd"> The ALiBi biases are added to the result of the Q*K^T product in the</span>
<span class="sd"> multi-head attention block.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> slopes : Tensor</span>
<span class="sd"> The slopes.</span>
<span class="sd"> key_length : Tensor</span>
<span class="sd"> The size of the K vector per head.</span>
<span class="sd"> Returns:</span>
<span class="sd"> A constant tensor that contains the ALiBi biases.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="c1"># We don&#39;t need to care about the batch size or query length since we can just broadcast</span>
<span class="c1"># across the batch and query dimensions</span>
<span class="n">trt_0</span> <span class="o">=</span> <span class="n">constant</span><span class="p">(</span><span class="n">int32_array</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
<span class="n">arange_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">key_length</span><span class="p">])</span>
<span class="n">arange_tensor</span> <span class="o">=</span> <span class="n">arange</span><span class="p">(</span><span class="n">trt_0</span><span class="p">,</span> <span class="n">key_length</span><span class="p">,</span> <span class="s2">&quot;float32&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">arange_shape</span><span class="p">)</span>
<span class="k">return</span> <span class="n">slopes</span> <span class="o">*</span> <span class="n">arange_tensor</span></div>
<div class="viewcode-block" id="expand_mask">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.expand_mask">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">expand_mask</span><span class="p">(</span><span class="n">mask</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">tgt_len</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Expand an attention mask.</span>
<span class="sd"> That function adds the sequence of operations to expand from a tensor of</span>
<span class="sd"> shape &#39;[batch_size, src_seq_len]&#39; to a tensor of shape</span>
<span class="sd"> &#39;[batch_size, 1, tgt_seq_len, src_seq_len]&#39;. It can be used to create the</span>
<span class="sd"> mask applied to the Q*K^T product before the softmax operation in the</span>
<span class="sd"> multi-head attention block.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> mask : Tensor</span>
<span class="sd"> The input mask</span>
<span class="sd"> tgt_len : Optional[Tensor]</span>
<span class="sd"> The dimension of the 3rd dimension in the output tensor. If None,</span>
<span class="sd"> the 2nd dimension of the input is used.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor created by that sequence of operations.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">bsz</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">src_len</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">tgt_len</span> <span class="o">=</span> <span class="n">tgt_len</span> <span class="k">if</span> <span class="n">tgt_len</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">src_len</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">([</span><span class="n">bsz</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">src_len</span><span class="p">]))</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">concat</span><span class="p">([</span><span class="n">bsz</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">tgt_len</span><span class="p">,</span> <span class="n">src_len</span><span class="p">]))</span>
<span class="n">mask</span> <span class="o">=</span> <span class="n">where</span><span class="p">(</span><span class="n">mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;-inf&#39;</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">)</span>
<span class="k">return</span> <span class="n">mask</span></div>
<div class="viewcode-block" id="gather_last_token_logits">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.gather_last_token_logits">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">gather_last_token_logits</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">last_token_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Extract the logits that correspond to the last token from the hidden states.</span>
<span class="sd"> That function adds the operations to extract the logits of the last tokens</span>
<span class="sd"> in a batch of sequences.</span>
<span class="sd"> Depending on whether &#39;remove_input_padding&#39; is &#39;True&#39; or &#39;False&#39;, that</span>
<span class="sd"> function assumes inputs of different shapes.</span>
<span class="sd"> When &#39;remove_input_padding&#39; is &#39;True&#39;, the &#39;hidden_states&#39; tensor is</span>
<span class="sd"> assumed to be packed. It has a shape &#39;[num_tokens, hidden_dim]&#39; where</span>
<span class="sd"> &#39;num_tokens&#39; is the sum of the lengths of the sequences in the batch and</span>
<span class="sd"> &#39;hidden_dim&#39; is the hidden dimension. The &#39;last_tokens_ids&#39; is a 1D tensor</span>
<span class="sd"> that encodes the inclusive prefix-sums of the lengths of the sequences in</span>
<span class="sd"> the batch.</span>
<span class="sd"> When &#39;remove_input_padding&#39; is &#39;False&#39;, the &#39;hidden_states&#39; tensor is</span>
<span class="sd"> assumed to be padded. It has a shape &#39;[batch_size, max_seqlen, hidden_dim]&#39;</span>
<span class="sd"> where &#39;max_seqlen&#39; is the length of the longest sequence in the batch and</span>
<span class="sd"> &#39;hidden_dim&#39; is the hidden dimension. The &#39;last_token_ids&#39; is a 1D tensor</span>
<span class="sd"> that encodes the length of each sequence in the batch.</span>
<span class="sd"> In both cases, that function produces a tensor of shape &#39;[batch_size,</span>
<span class="sd"> hidden_size]&#39; where the row at index &#39;i&#39; corresponds to the logits of the</span>
<span class="sd"> last token from the &#39;i&#39;-th sequence.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> hidden_states : Tensor</span>
<span class="sd"> The hidden states</span>
<span class="sd"> last_token_ids : Tensor</span>
<span class="sd"> The inclusive prefix-sum of the lengths or the lengths of the</span>
<span class="sd"> sequences in the batch.</span>
<span class="sd"> remove_input_padding : bool</span>
<span class="sd"> Indicate if the hidden_states are packed (&#39;True&#39;) or padded</span>
<span class="sd"> (&#39;False&#39;).</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor created by that sequence of operations.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">last_token_ids</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">return</span> <span class="n">hidden_states</span>
<span class="k">if</span> <span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">index_select</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># [seq_len, hidden]</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">hidden_states</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">1</span><span class="p">)]))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">ndim</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span>
<span class="k">if</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="c1"># only calculate logits for the last token</span>
<span class="c1"># [batch_size, seqlen, hidden_size] -&gt; [batch_size, hidden_size]</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">]))</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span>
<span class="n">last_token_ids</span><span class="p">,</span>
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">2</span><span class="p">)]))</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span> <span class="o">-</span> <span class="mi">1</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span>
<span class="n">hidden_states</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">indices</span><span class="o">=</span><span class="n">last_token_ids</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">2</span><span class="p">)]))</span>
<span class="k">elif</span> <span class="n">ndim</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span> <span class="c1"># speculative decoding needs last few token&#39;s logits</span>
<span class="c1"># last_token_ids is of shape [batch_size, num_last_tokens]</span>
<span class="c1"># So [batch_size, seqlen, hidden_size] -&gt; [batch_size, num_last_tokens, hidden_size]</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">last_token_ids</span><span class="o">.</span><span class="n">view</span><span class="p">(</span>
<span class="n">concat</span><span class="p">([</span><span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="mi">1</span><span class="p">]))</span>
<span class="n">last_token_ids</span> <span class="o">=</span> <span class="n">expand</span><span class="p">(</span>
<span class="n">last_token_ids</span><span class="p">,</span>
<span class="n">concat</span><span class="p">([</span>
<span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">shape</span><span class="p">(</span><span class="n">last_token_ids</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">shape</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="p">]))</span>
<span class="n">hidden_states</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span><span class="n">hidden_states</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">indices</span><span class="o">=</span><span class="n">last_token_ids</span><span class="p">)</span>
<span class="k">return</span> <span class="n">hidden_states</span></div>
<span class="n">ACT2FN</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;relu&#39;</span><span class="p">:</span> <span class="n">relu</span><span class="p">,</span>
<span class="s1">&#39;tanh&#39;</span><span class="p">:</span> <span class="n">tanh</span><span class="p">,</span>
<span class="s1">&#39;gelu&#39;</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
<span class="s1">&#39;gelu_new&#39;</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
<span class="s1">&#39;gelu_fast&#39;</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
<span class="s1">&#39;gelu_pytorch_tanh&#39;</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
<span class="s1">&#39;openai-gelu&#39;</span><span class="p">:</span> <span class="n">gelu</span><span class="p">,</span>
<span class="s1">&#39;geglu&#39;</span><span class="p">:</span> <span class="n">geglu</span><span class="p">,</span>
<span class="s1">&#39;gegelu&#39;</span><span class="p">:</span> <span class="n">gegelu</span><span class="p">,</span>
<span class="s1">&#39;identity&#39;</span><span class="p">:</span> <span class="n">identity</span><span class="p">,</span>
<span class="s1">&#39;silu&#39;</span><span class="p">:</span> <span class="n">silu</span><span class="p">,</span>
<span class="s1">&#39;softplus&#39;</span><span class="p">:</span> <span class="n">softplus</span><span class="p">,</span>
<span class="s1">&#39;relu2&#39;</span><span class="p">:</span> <span class="n">squared_relu</span><span class="p">,</span>
<span class="s1">&#39;squared-relu&#39;</span><span class="p">:</span> <span class="n">squared_relu</span><span class="p">,</span>
<span class="s1">&#39;swiglu&#39;</span><span class="p">:</span> <span class="n">swiglu</span><span class="p">,</span>
<span class="s1">&#39;fast-swiglu&#39;</span><span class="p">:</span> <span class="n">swiglu</span><span class="p">,</span>
<span class="s1">&#39;sigmoid&#39;</span><span class="p">:</span> <span class="n">sigmoid</span><span class="p">,</span>
<span class="s1">&#39;quick_gelu&#39;</span><span class="p">:</span> <span class="n">quick_gelu</span><span class="p">,</span>
<span class="p">}</span>
<span class="n">GATED_ACT_2_ACT</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;swiglu&#39;</span><span class="p">:</span> <span class="s1">&#39;silu&#39;</span><span class="p">,</span>
<span class="s1">&#39;fast-swiglu&#39;</span><span class="p">:</span> <span class="s1">&#39;silu&#39;</span><span class="p">,</span>
<span class="s1">&#39;geglu&#39;</span><span class="p">:</span> <span class="s1">&#39;gelu&#39;</span><span class="p">,</span>
<span class="p">}</span>
<div class="viewcode-block" id="is_gated_activation">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.is_gated_activation">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">is_gated_activation</span><span class="p">(</span><span class="n">activation</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Is a given activation function gated?</span>
<span class="sd"> Parameters:</span>
<span class="sd"> activation : str</span>
<span class="sd"> The name of the activation function.</span>
<span class="sd"> Returns:</span>
<span class="sd"> True if the function is gated, False otherwise.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="n">activation</span> <span class="ow">in</span> <span class="n">ACT2FN</span>
<span class="k">return</span> <span class="n">activation</span> <span class="ow">in</span> <span class="n">GATED_ACT_2_ACT</span></div>
<div class="viewcode-block" id="non_gated_version">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.non_gated_version">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">non_gated_version</span><span class="p">(</span><span class="n">activation</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Given an activation function, get the non-gated version.</span>
<span class="sd"> If the activation function is non-gated, it returns the same activation</span>
<span class="sd"> function name.</span>
<span class="sd"> For example, that function returns &#39;silu&#39; for &#39;swiglu&#39; and &#39;relu&#39; for</span>
<span class="sd"> &#39;relu&#39;.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> activation : str</span>
<span class="sd"> The name of the activation function.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The name of the non-gated activation function.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">if</span> <span class="n">is_gated_activation</span><span class="p">(</span><span class="n">activation</span><span class="p">):</span>
<span class="k">return</span> <span class="n">GATED_ACT_2_ACT</span><span class="p">[</span><span class="n">activation</span><span class="p">]</span>
<span class="k">return</span> <span class="n">activation</span></div>
<div class="viewcode-block" id="lora_plugin">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.lora_plugin">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">lora_plugin</span><span class="p">(</span>
<span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">in_hidden_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">out_hidden_sizes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">],</span>
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">transa</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">transb</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># for pad-free input mode</span>
<span class="n">max_low_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">lora_ranks</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lora_weights_pointers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">weight_index</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor (On GPU)</span>
<span class="sd"> The input tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding</span>
<span class="sd"> in_hidden_size/out_hidden_size : int</span>
<span class="sd"> the lora computation workflow is</span>
<span class="sd"> [M, in_hidden_size] -&gt; [M, low_rank] -&gt; [M, out_hidden_size]</span>
<span class="sd"> host_request_types : Tensor = None</span>
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
<span class="sd"> in docs/source/advanced/gpt-attention.md,</span>
<span class="sd"> transa : bool</span>
<span class="sd"> Is the first input transposed? Set to &#39;True&#39; if you want the first</span>
<span class="sd"> input to be transposed, &#39;False&#39; otherwise.</span>
<span class="sd"> transb : bool</span>
<span class="sd"> Is the second input transposed? Set to &#39;True&#39; if you want the</span>
<span class="sd"> second input to be transposed, &#39;False&#39; otherwise.</span>
<span class="sd"> host_context_lengths: cpu Tensor = None</span>
<span class="sd"> A host tensor that contains the lengths of the different inputs,</span>
<span class="sd"> max_low_rank : int</span>
<span class="sd"> Maximum low_rank, used to determine the workspace size.</span>
<span class="sd"> lora_ranks : cpu Tensor with shape [batch_size]</span>
<span class="sd"> The low_rank of each request</span>
<span class="sd"> lora_weights_pointers : cpu int64 Tensor with shape [batch_size, 3]</span>
<span class="sd"> The weights pointers of each request. Consist of in_pointer, out_pointer and possibly a scales vector pointer.</span>
<span class="sd"> weight_index : int</span>
<span class="sd"> The index of weight if the weight pointer pointing to multiple weights.</span>
<span class="sd"> Return:</span>
<span class="sd"> The tensor produced by that layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="n">host_context_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">(</span>
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span>
<span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_creator_list</span>
<span class="n">in_hidden_size_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;in_hidden_size&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">in_hidden_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">out_hidden_size_field_list</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;out_hidden_size_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">o</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">o</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">out_hidden_sizes</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">transa</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">transa</span> <span class="k">else</span> <span class="mi">0</span>
<span class="n">transa</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;transa&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">transa</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">transb</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">transb</span> <span class="k">else</span> <span class="mi">0</span>
<span class="n">transb</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;transb&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">transb</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;Lora&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lora_plugin</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;remove_input_padding&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">max_low_rank_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;max_low_rank&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">max_low_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">weight_index_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;weight_index&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">weight_index</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">num_lora_modules</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">out_hidden_sizes</span><span class="p">)</span>
<span class="n">num_lora_modules_field</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;num_lora_modules&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">num_lora_modules</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
<span class="n">in_hidden_size_field</span><span class="p">,</span> <span class="n">transa</span><span class="p">,</span> <span class="n">transb</span><span class="p">,</span> <span class="n">num_lora_modules_field</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="p">,</span> <span class="n">max_low_rank_field</span><span class="p">,</span> <span class="n">weight_index_field</span>
<span class="p">]</span> <span class="o">+</span> <span class="n">out_hidden_size_field_list</span><span class="p">)</span>
<span class="n">lora_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;lora&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">),</span> <span class="n">host_request_types</span>
<span class="p">]</span> <span class="o">+</span> <span class="n">lora_ranks</span> <span class="o">+</span> <span class="n">lora_weights_pointers</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_lengths</span><span class="p">]</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">lora_plug</span><span class="p">)</span>
<span class="k">if</span> <span class="n">num_lora_modules</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="p">[</span>
<span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="n">i</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_lora_modules</span><span class="p">)</span>
<span class="p">]</span></div>
<div class="viewcode-block" id="dora_plugin">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.dora_plugin">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">dora_plugin</span><span class="p">(</span><span class="n">activations</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">out_hidden_sizes</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
<span class="n">lora_weights_pointers</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Tensor</span> <span class="o">|</span> <span class="kc">None</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> The DoRA plugin applies column-wise scaling to the output of a LoRA layer.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor (On GPU)</span>
<span class="sd"> The input tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding</span>
<span class="sd"> out_hidden_sizes : list[int]</span>
<span class="sd"> The output hidden size of each adapter in the related LoRA module.</span>
<span class="sd"> For example, for a qkv projection out_hidden_sizes should be [q_dim, k_dim, v_dim].</span>
<span class="sd"> host_request_types : Tensor = None</span>
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
<span class="sd"> in docs/source/advanced/gpt-attention.md,</span>
<span class="sd"> host_context_lengths: cpu Tensor = None</span>
<span class="sd"> A host tensor that contains the lengths of the different inputs,</span>
<span class="sd"> Return:</span>
<span class="sd"> The tensor produced by that layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="n">host_context_lengths</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">(</span>
<span class="p">)</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span>
<span class="n">dora_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_creator</span><span class="p">(</span>
<span class="s1">&#39;Dora&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">dora_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">out_hidden_sizes</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;out_hidden_sizes&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">out_hidden_sizes</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;remove_input_padding&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">lora_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lora_plugin</span>
<span class="n">type_id</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">lora_dtype</span><span class="p">)),</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">(</span>
<span class="p">[</span><span class="n">type_id</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span> <span class="n">out_hidden_sizes</span><span class="p">])</span>
<span class="n">dora_plug</span> <span class="o">=</span> <span class="n">dora_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;dora&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">TensorRTPhase</span><span class="o">.</span><span class="n">BUILD</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">activations</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">lora_dtype</span><span class="p">),</span> <span class="n">host_request_types</span>
<span class="p">]</span> <span class="o">+</span> <span class="n">lora_weights_pointers</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_lengths</span><span class="p">]</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v3</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="p">[],</span> <span class="n">dora_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">dora_plg_creator</span><span class="p">,</span> <span class="s2">&quot;dora&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">activations</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="mamba_conv1d">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.mamba_conv1d">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">mamba_conv1d</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">conv_state_or_ptr</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">conv_weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">conv_bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dconv</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">pre_stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">post_stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">slot_mapping</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">apply_silu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor (On GPU)</span>
<span class="sd"> The input tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding</span>
<span class="sd"> conv_state_or_ptr : Tensor (On GPU or CPU)</span>
<span class="sd"> The conv state tensor. Its shape is [batch_size, dconv - 1, dim]</span>
<span class="sd"> Or the CPU tensor of shape [1] for the pointer of paged states.</span>
<span class="sd"> conv_weight : Tensor (On GPU)</span>
<span class="sd"> The weight tensor. Its shape is [1, dconv, dim]</span>
<span class="sd"> conv_bias : Tensor (On GPU)</span>
<span class="sd"> The bias tensor. Its shape is [dim]</span>
<span class="sd"> host_request_types : Tensor (On CPU)</span>
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
<span class="sd"> in docs/source/advanced/gpt-attention.md,</span>
<span class="sd"> last_token_ids : Tensor (On GPU)</span>
<span class="sd"> The inclusive prefix-sum of the lengths or the lengths of the</span>
<span class="sd"> sequences in the batch.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The hidden dimension of conv1d</span>
<span class="sd"> dconv : int</span>
<span class="sd"> The window size of conv1d</span>
<span class="sd"> dtype: str</span>
<span class="sd"> data type</span>
<span class="sd"> pre_stride : int = 0</span>
<span class="sd"> The (pre) stride size of the input tensor.</span>
<span class="sd"> The valid values of the input tensor are input[..., pre_stride: dim-post_stride]</span>
<span class="sd"> post_stride : int = 0</span>
<span class="sd"> The (post) stride size of the input tensor.</span>
<span class="sd"> The valid values of the input tensor are input[..., pre_stride: dim-post_stride]</span>
<span class="sd"> host_context_lengths: Tensor (On CPU) (Optional)</span>
<span class="sd"> A host tensor that contains the lengths of the different inputs,</span>
<span class="sd"> slot_mapping: Tensor (On GPU) (Optional)</span>
<span class="sd"> Real page index in state. Its shape is [dim], used for paged state, each page shape is [dconv, dim]</span>
<span class="sd"> apply_silu: bool</span>
<span class="sd"> Is there a SiLU operation after the conv1d? When True apply</span>
<span class="sd"> SiLU activation function after the conv1d.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="n">host_request_types</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">mamba_conv1d_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;MambaConv1d&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">mamba_conv1d_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;dim&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">dconv</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;dconv&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dconv</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pre_stride</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;pre_stride&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">pre_stride</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">post_stride</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;post_stride&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">post_stride</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;remove_input_padding&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">paged_state</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;paged_state&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">apply_silu</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;apply_silu&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">apply_silu</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
<span class="n">dim</span><span class="p">,</span> <span class="n">dconv</span><span class="p">,</span> <span class="n">pre_stride</span><span class="p">,</span> <span class="n">post_stride</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span>
<span class="n">paged_state</span><span class="p">,</span> <span class="n">apply_silu</span>
<span class="p">])</span>
<span class="n">mamba_conv1d_plug</span> <span class="o">=</span> <span class="n">mamba_conv1d_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span>
<span class="s2">&quot;mamba_conv1d&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span>
<span class="nb">input</span><span class="p">,</span> <span class="n">conv_state_or_ptr</span><span class="p">,</span> <span class="n">conv_weight</span><span class="p">,</span> <span class="n">conv_bias</span><span class="p">,</span> <span class="n">host_request_types</span><span class="p">,</span>
<span class="n">last_token_ids</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_lengths</span><span class="p">]</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">slot_mapping</span><span class="p">]</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">mamba_conv1d_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">mamba_conv1d_plg_creator</span><span class="p">,</span> <span class="s2">&quot;mamba_conv1d&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">present_state</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">present_state</span></div>
<div class="viewcode-block" id="selective_scan">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.selective_scan">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">selective_scan</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">state_or_ptr</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">delta</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">delta_bias</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">A</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">BC</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">D</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dstate</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dt_rank</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">delta_softplus</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">z</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">slot_mapping</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">nheads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">ngroups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">chunk_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span><span class="p">,</span>
<span class="n">mamba_version</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;Mamba1&#39;</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor (On GPU)</span>
<span class="sd"> The input tensor. Its shape is [batch_size, seq_len, dim]</span>
<span class="sd"> state_or_ptr : Tensor (On GPU or CPU)</span>
<span class="sd"> The ssm state tensor. Its shape is [batch_size, dstate, dim]</span>
<span class="sd"> Or the CPU tensor of shape [1] for the pointer of paged states.</span>
<span class="sd"> delta : Tensor (On GPU)</span>
<span class="sd"> The delta tensor.</span>
<span class="sd"> mamba: Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding</span>
<span class="sd"> mamba2: Its shape is [batch_size, seq_len, nheads] or [num_tokens, nheads] for remove_input_padding</span>
<span class="sd"> delta_bias : Tensor (On GPU)</span>
<span class="sd"> The delta bias tensor.</span>
<span class="sd"> mamba: Its shape is [dim]</span>
<span class="sd"> mamba2: Its shape is [nheads]</span>
<span class="sd"> A : Tensor (On GPU)</span>
<span class="sd"> A matrix.</span>
<span class="sd"> mamba: Its shape is [dstate, dim]</span>
<span class="sd"> mamba2: Its shape is [nheads]</span>
<span class="sd"> BC : Tensor (On GPU)</span>
<span class="sd"> B and C matrix.</span>
<span class="sd"> mamba: Its shape is [batch_size, seq_len, dstate * 2] or [num_tokens, dstate * 2] for remove_input_padding</span>
<span class="sd"> mamba2: Its shape is [batch_size, seq_len, ngroups * dstate * 2] or [num_tokens, ngroups * dstate * 2] for remove_input_padding</span>
<span class="sd"> D : Tensor (On GPU)</span>
<span class="sd"> D matrix.</span>
<span class="sd"> mamba: Its shape is [dim]</span>
<span class="sd"> mamba2: Its shape is [nheads]</span>
<span class="sd"> host_request_types : Tensor (On CPU)</span>
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
<span class="sd"> in docs/source/advanced/gpt-attention.md</span>
<span class="sd"> last_token_ids : Tensor (On GPU)</span>
<span class="sd"> The inclusive prefix-sum of the lengths or the lengths of the</span>
<span class="sd"> sequences in the batch.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The inner dimension of SSM block</span>
<span class="sd"> dstate : int</span>
<span class="sd"> The state dimension of SSM block</span>
<span class="sd"> dt_rank: int</span>
<span class="sd"> The rank dimension of dt_proj</span>
<span class="sd"> delta_softplus : bool</span>
<span class="sd"> Do we apply softplus to the delta.</span>
<span class="sd"> dtype: str</span>
<span class="sd"> data type</span>
<span class="sd"> z : Tensor (On GPU) (Optional)</span>
<span class="sd"> The z tensor. Its shape is [batch_size, seq_len, dim] or [num_tokens, dim] for remove_input_padding</span>
<span class="sd"> host_context_lengths: Tensor (On CPU) (Optional)</span>
<span class="sd">            A host tensor that contains the lengths of the different inputs.</span>
<span class="sd"> slot_mapping: Tensor (On GPU) (Optional)</span>
<span class="sd"> Real page index in state. Its shape is [dim], used for paged state, each page shape is [dstate, dim]</span>
<span class="sd"> nheads: int (Optional)</span>
<span class="sd"> The number of heads.</span>
<span class="sd"> ngroups: int (Optional)</span>
<span class="sd"> The number of groups.</span>
<span class="sd"> chunk_size: int (Optional)</span>
<span class="sd"> The chunk_size is used for the chunk_scan kernel.</span>
<span class="sd">        mamba_version: str (Optional)</span>
<span class="sd">            The Mamba version, either &#39;Mamba1&#39; or &#39;Mamba2&#39;. Defaults to &#39;Mamba1&#39;.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="n">host_request_types</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">selective_scan_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;SelectiveScan&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">selective_scan_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;dim&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">dstate</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;dstate&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dstate</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">dt_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;dt_rank&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dt_rank</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">nheads</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;nheads&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">nheads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">ngroups</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;ngroups&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">ngroups</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">chunk_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;chunk_size&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">chunk_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">delta_softplus</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;delta_softplus&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">delta_softplus</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;remove_input_padding&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">paged_state</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;paged_state&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="k">if</span> <span class="n">z</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">z_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;z_enabled&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">z_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;z_enabled&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">is_mamba2</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;is_mamba2&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span> <span class="k">if</span> <span class="n">mamba_version</span> <span class="o">==</span> <span class="s1">&#39;Mamba2&#39;</span> <span class="k">else</span> <span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
<span class="n">dim</span><span class="p">,</span> <span class="n">dstate</span><span class="p">,</span> <span class="n">dt_rank</span><span class="p">,</span> <span class="n">nheads</span><span class="p">,</span> <span class="n">ngroups</span><span class="p">,</span> <span class="n">chunk_size</span><span class="p">,</span> <span class="n">delta_softplus</span><span class="p">,</span>
<span class="n">pf_type</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span> <span class="n">paged_state</span><span class="p">,</span> <span class="n">z_enabled</span><span class="p">,</span> <span class="n">is_mamba2</span>
<span class="p">])</span>
<span class="n">selective_scan_plug</span> <span class="o">=</span> <span class="n">selective_scan_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span>
<span class="s2">&quot;selective_scan&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span>
<span class="nb">input</span><span class="p">,</span> <span class="n">state_or_ptr</span><span class="p">,</span> <span class="n">delta</span><span class="p">,</span> <span class="n">delta_bias</span><span class="p">,</span> <span class="n">A</span><span class="p">,</span> <span class="n">BC</span><span class="p">,</span> <span class="n">D</span><span class="p">,</span> <span class="n">host_request_types</span><span class="p">,</span>
<span class="n">last_token_ids</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">host_context_lengths</span><span class="p">]</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">slot_mapping</span><span class="p">]</span>
<span class="k">if</span> <span class="n">z</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">z</span><span class="p">]</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">selective_scan_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">selective_scan_plg_creator</span><span class="p">,</span> <span class="s2">&quot;selective_scan&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">present_state</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">present_state</span></div>
<div class="viewcode-block" id="rg_lru">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.rg_lru">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">rg_lru</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">A</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">state_or_ptr</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">block_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="n">y</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">y_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">gate</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">gate_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">gate_x</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">gate_x_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">gate_a</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">gate_a_bias</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">slot_mapping</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor (On GPU)</span>
<span class="sd"> The input tensor. Its shape is [batch_size, seq_len, dim]</span>
<span class="sd"> A : Tensor (On GPU)</span>
<span class="sd"> A matrix. Its shape is [dim]</span>
<span class="sd"> state_or_ptr : Tensor (On GPU or CPU)</span>
<span class="sd"> The lru state tensor. Its shape is [batch_size, dstate, dim]</span>
<span class="sd"> Or the CPU tensor of shape [1] for the pointer of paged states.</span>
<span class="sd"> host_request_types : Tensor (On CPU)</span>
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
<span class="sd">            in docs/source/advanced/gpt-attention.md.</span>
<span class="sd"> last_token_ids : Tensor (On GPU)</span>
<span class="sd"> The inclusive prefix-sum of the lengths or the lengths of the</span>
<span class="sd"> sequences in the batch.</span>
<span class="sd"> dim : int</span>
<span class="sd"> The inner dimension of RG_LRU block</span>
<span class="sd"> block_size : int</span>
<span class="sd"> The block size of the block diagonal linear layer. It is used to</span>
<span class="sd"> support the cases that enable fused gate.</span>
<span class="sd"> dtype: str</span>
<span class="sd"> data type</span>
<span class="sd"> y : Tensor (On GPU) (Optional)</span>
<span class="sd"> The y tensor. Its shape is [batch_size, seq_len, dim]</span>
<span class="sd"> y_bias : Tensor (On GPU) (Optional)</span>
<span class="sd"> The y_bias tensor. Its shape is [dim]. If y_bias is not None, we</span>
<span class="sd"> will fuse GELU(y + y_bias) in this function.</span>
<span class="sd"> gate : Tensor (On GPU) (Optional)</span>
<span class="sd"> The gate tensor. Its shape is [batch_size, seq_len, 2 * dim].</span>
<span class="sd"> If gate is not None, we will fuse the gate_x and gate_a, otherwise</span>
<span class="sd"> use those two tensors.</span>
<span class="sd"> gate_bias : Tensor (On GPU) (Optional)</span>
<span class="sd"> The gate_bias tensor. Its shape is [2 * block_num, dim // block_num].</span>
<span class="sd"> If gate_bias is not None, we will fuse the bias add in this function.</span>
<span class="sd"> gate_x : Tensor (On GPU) (Optional)</span>
<span class="sd"> The gate_x tensor. Its shape is [batch_size, seq_len, dim]</span>
<span class="sd"> gate_x_bias : Tensor (On GPU) (Optional)</span>
<span class="sd"> The gate_x_bias tensor. Its shape is [block_num, dim // block_num].</span>
<span class="sd"> If gate_x_bias is not None, we will fuse the bias add in this function.</span>
<span class="sd"> gate_a : Tensor (On GPU) (Optional)</span>
<span class="sd"> The gate_a tensor. Its shape is [batch_size, seq_len, dim]</span>
<span class="sd"> gate_a_bias : Tensor (On GPU) (Optional)</span>
<span class="sd"> The gate_a_bias tensor. Its shape is [block_num, dim // block_num].</span>
<span class="sd"> If gate_a_bias is not None, we will fuse the bias add in this function.</span>
<span class="sd"> slot_mapping: Tensor (On GPU) (Optional)</span>
<span class="sd"> Real page index in state. Its shape is [dim], used for paged state, each page shape is [dstate, dim]</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="n">host_request_types</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">lru_plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;LRU&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">lru_plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="p">(</span><span class="n">gate_x_bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">)</span> <span class="o">==</span> <span class="p">(</span><span class="n">gate_a_bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">)</span>
<span class="n">enable_fuse_gate</span> <span class="o">=</span> <span class="n">gate</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">has_gate_bias</span> <span class="o">=</span> <span class="p">(</span><span class="n">gate_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">gate_x_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">if</span> <span class="n">enable_fuse_gate</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">gate</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">assert</span> <span class="n">block_size</span> <span class="o">&gt;</span> <span class="mi">0</span>
<span class="k">if</span> <span class="n">has_gate_bias</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">gate_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">gate_x</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">gate_a</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">has_gate_bias</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">gate_x_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">gate_a_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">dim</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;dim&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">block_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;block_size&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">block_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">remove_input_padding</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;remove_input_padding&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">paged_state</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;paged_state&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">(</span><span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">),</span>
<span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="k">if</span> <span class="n">y</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">y_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;y_enabled&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">y_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;y_enabled&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="k">if</span> <span class="n">y_bias</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">y_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;y_bias_enabled&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">y_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;y_bias_enabled&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="k">if</span> <span class="n">enable_fuse_gate</span><span class="p">:</span>
<span class="n">fuse_gate_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;fuse_gate_enabled&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">fuse_gate_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;fuse_gate_enabled&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="k">if</span> <span class="n">has_gate_bias</span><span class="p">:</span>
<span class="n">gate_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;gate_bias_enabled&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">gate_bias_enabled</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;gate_bias_enabled&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span>
<span class="n">dim</span><span class="p">,</span> <span class="n">block_size</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">,</span> <span class="n">remove_input_padding</span><span class="p">,</span> <span class="n">paged_state</span><span class="p">,</span> <span class="n">y_enabled</span><span class="p">,</span>
<span class="n">y_bias_enabled</span><span class="p">,</span> <span class="n">fuse_gate_enabled</span><span class="p">,</span> <span class="n">gate_bias_enabled</span>
<span class="p">])</span>
<span class="n">lru_plug</span> <span class="o">=</span> <span class="n">lru_plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;rg_lru&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span>
<span class="nb">input</span><span class="p">,</span>
<span class="n">A</span><span class="p">,</span>
<span class="n">state_or_ptr</span><span class="p">,</span>
<span class="n">host_request_types</span><span class="p">,</span>
<span class="n">last_token_ids</span><span class="p">,</span>
<span class="p">]</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">slot_mapping</span><span class="p">]</span>
<span class="k">if</span> <span class="n">y</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">y</span><span class="p">]</span>
<span class="k">if</span> <span class="n">y_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">y_bias</span><span class="p">]</span>
<span class="k">if</span> <span class="n">enable_fuse_gate</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gate</span><span class="p">]</span>
<span class="k">if</span> <span class="n">has_gate_bias</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gate_bias</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gate_x</span><span class="p">,</span> <span class="n">gate_a</span><span class="p">]</span>
<span class="k">if</span> <span class="n">has_gate_bias</span><span class="p">:</span>
<span class="n">plug_inputs</span> <span class="o">+=</span> <span class="p">[</span><span class="n">gate_x_bias</span><span class="p">,</span> <span class="n">gate_a_bias</span><span class="p">]</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">lru_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">lru_plg_creator</span><span class="p">,</span> <span class="s2">&quot;rg_lru&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">if</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">paged_state</span><span class="p">:</span>
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">present_state</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">present_state</span></div>
<div class="viewcode-block" id="topk">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.topk">[docs]</a>
<!-- topk: uses the fast TopkLastDim plugin when `dim` is the last axis and `k` is a
     static int; otherwise falls back to TensorRT's native ITopKLayer (add_topk),
     where a Tensor-valued `k` is wired in dynamically via layer.set_input(1, ...). -->
<span class="k">def</span><span class="w"> </span><span class="nf">topk</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">k</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span>
<span class="n">dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">largest</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">prefer_plugin</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a topk operation.</span>
<span class="sd"> As explained in the ONNX documentation,</span>
<span class="sd"> https://github.com/onnx/onnx/blob/main/docs/Operators.md#topk</span>
<span class="sd"> NOTE: One distinction from the ONNX topk op, the output is always sorted</span>
<span class="sd"> with TensorRT layer.</span>
<span class="sd"> Retrieve the top-K largest elements along a specified axis.</span>
<span class="sd"> Given an input tensor of shape [a_1, a_2, ..., a_n, r]</span>
<span class="sd"> and integer argument k, return two outputs:</span>
<span class="sd"> Value tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which contains the values of the top k elements along the specified axis</span>
<span class="sd"> Index tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which contains the indices of the top k elements (original indices from the input tensor).</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor.</span>
<span class="sd"> k : Union[Tensor, int]</span>
<span class="sd"> A single positive value corresponding to the number of top elements to retrieve.</span>
<span class="sd"> When k is a Tensor, the plugin path is skipped and k is supplied</span>
<span class="sd"> dynamically to the native TensorRT TopK layer via set_input.</span>
<span class="sd"> dim: int</span>
<span class="sd"> The dimension in which to compute the topk indices.</span>
<span class="sd"> Negative values are resolved (via dim_resolve_negative) relative to input.ndim().</span>
<span class="sd"> largest: bool</span>
<span class="sd"> Controls whether to return largest or smallest elements.</span>
<span class="sd"> prefer_plugin : bool</span>
<span class="sd"> Whether to use the topkLastDim plugin if dim is last dim and k is static.</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensors (values, indices) produced by this topk operation.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">dim</span> <span class="o">=</span> <span class="n">dim_resolve_negative</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">())[</span><span class="mi">0</span><span class="p">]</span>
<!-- Plugin path: only valid for the last axis with a compile-time (int) k. -->
<span class="k">if</span> <span class="n">prefer_plugin</span> <span class="ow">and</span> <span class="n">dim</span> <span class="o">==</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span> <span class="ow">and</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="n">last_dim</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="n">last_dim</span> <span class="o">==</span> <span class="o">-</span><span class="mi">1</span><span class="p">:</span> <span class="c1"># dynamic?</span>
<span class="n">last_dim</span> <span class="o">=</span> <span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># since we might need to flatten the input to 2d tensor,</span>
<span class="c1"># we need to prepare the output shape</span>
<span class="n">out_shape</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
<span class="n">out_shape</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">shape</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span>
<span class="n">out_shape</span> <span class="o">=</span> <span class="n">concat</span><span class="p">(</span><span class="n">out_shape</span> <span class="o">+</span> <span class="p">[</span><span class="n">k</span><span class="p">])</span>
<span class="k">if</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">input_2d</span> <span class="o">=</span> <span class="n">unsqueeze</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span>
<span class="mi">0</span><span class="p">)</span> <span class="c1"># special handling of rank-1 dynamic tensor</span>
<span class="k">elif</span> <span class="nb">input</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">!=</span> <span class="mi">2</span><span class="p">:</span>
<span class="n">input_2d</span> <span class="o">=</span> <span class="nb">input</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">concat</span><span class="p">([</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">last_dim</span><span class="p">]),</span>
<span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">input_2d</span> <span class="o">=</span> <span class="nb">input</span>
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s2">&quot;TopkLastDim&quot;</span><span class="p">,</span> <span class="s2">&quot;1&quot;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">is_largest</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;is_largest&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mi">1</span> <span class="k">if</span> <span class="n">largest</span> <span class="k">else</span> <span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;k&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;type_id&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">input_2d</span><span class="o">.</span><span class="n">dtype</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">pf_type</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">is_largest</span><span class="p">])</span>
<span class="n">topk_last_dim_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;topk_last_dim&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">input_2d</span><span class="p">]</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">i</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">plug_inputs</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">topk_last_dim_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creator</span><span class="p">,</span> <span class="s2">&quot;topk_last_dim&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">values</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="n">indices</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<!-- Restore the original leading dimensions that were flattened away above. -->
<span class="n">values</span> <span class="o">=</span> <span class="n">values</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">out_shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">indices</span> <span class="o">=</span> <span class="n">indices</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">out_shape</span><span class="p">,</span> <span class="n">zero_is_placeholder</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># non-plugin path</span>
<span class="n">axes</span> <span class="o">=</span> <span class="n">dim_to_trt_axes</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_topk</span><span class="p">(</span>
<span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">TopKOperation</span><span class="o">.</span><span class="n">MAX</span> <span class="k">if</span> <span class="n">largest</span> <span class="k">else</span> <span class="n">trt</span><span class="o">.</span><span class="n">TopKOperation</span><span class="o">.</span><span class="n">MIN</span><span class="p">,</span>
<span class="n">k</span><span class="o">=</span><span class="n">k</span> <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">axes</span><span class="o">=</span><span class="n">axes</span><span class="p">)</span>
<!-- For a Tensor k, the static k=1 above is a placeholder; the real value is
     bound as dynamic input 1 of the TopK layer here. -->
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="k">if</span> <span class="n">k</span><span class="o">.</span><span class="n">ndim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">squeeze</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">layer</span><span class="o">.</span><span class="n">set_input</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">k</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">)</span>
<span class="n">values</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="n">indices</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">return</span> <span class="n">values</span><span class="p">,</span> <span class="n">indices</span></div>
<div class="viewcode-block" id="scatter_nd">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.scatter_nd">[docs]</a>
<!-- scatter_nd: thin wrapper over TensorRT's IScatterLayer in ScatterMode.ND.
     Despite its name, `mask` is an ND indices tensor (see docstring); the exact
     index layout follows the TensorRT ND scatter semantics — confirm against
     the IScatterLayer documentation. -->
<span class="k">def</span><span class="w"> </span><span class="nf">scatter_nd</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">source</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Scatter_nd is a tensor operation that writes or updates values in a tensor based on indices.</span>
<span class="sd"> The operation is out-of-place: the input tensor itself is not modified.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input: Tensor</span>
<span class="sd"> The input tensor to be updated</span>
<span class="sd"> mask: Tensor</span>
<span class="sd"> A tensor of indices specifying the locations in data to be updated.</span>
<span class="sd"> source: Tensor</span>
<span class="sd"> A tensor of values to be written or scattered into data.</span>
<span class="sd"> Returns:</span>
<span class="sd"> New tensor with the same shape as the input tensor data,</span>
<span class="sd"> where the values from the source tensor are scattered or written into the output tensor</span>
<span class="sd"> at the locations specified by the mask tensor.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">scatter_layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_scatter</span><span class="p">(</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">mask</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">source</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">mode</span><span class="o">=</span><span class="n">trt</span><span class="o">.</span><span class="n">ScatterMode</span><span class="o">.</span><span class="n">ND</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">scatter_layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">scatter_layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="low_latency_gemm">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.low_latency_gemm">[docs]</a><!-- Sphinx viewcode block: highlighted source of tensorrt_llm.functional.low_latency_gemm (fp8-only LowLatencyGemm plugin wrapper). Generated page; edit the Python source, not this HTML. -->
<span class="k">def</span><span class="w"> </span><span class="nf">low_latency_gemm</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">mat2</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">alpha</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">strict_dtype</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">low_latency_gemm_plugin</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s2">&quot;Low Latency GEMM is only support with plugin&quot;</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">low_latency_gemm_plugin</span> <span class="o">!=</span> <span class="s2">&quot;fp8&quot;</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s2">&quot;Low Latency GEMM plugin only support fp8&quot;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s2">&quot;LowLatencyGemm&quot;</span><span class="p">,</span> <span class="s2">&quot;1&quot;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">if</span> <span class="p">((</span><span class="nb">input</span><span class="o">.</span><span class="n">dtype</span> <span class="o">!=</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span><span class="p">)</span> <span class="ow">or</span> <span class="p">((</span><span class="n">mat2</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> <span class="o">!=</span> <span class="n">trt</span><span class="o">.</span><span class="n">fp8</span><span class="p">)):</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;Low Latency GEMM only support fp8 input&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="p">(</span><span class="n">alpha</span><span class="p">):</span><!-- NOTE(review): this relies on np.ndarray truthiness, not an `is not None` check; a size-1 array holding 0.0 is falsy, so alpha=0.0 would skip validation and be replaced by the 1.0 default below. Presumably unintended; verify and fix in functional.py. -->
<span class="k">assert</span> <span class="p">(</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">alpha</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">)</span> <span class="ow">and</span> <span class="n">alpha</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span>
<span class="ow">and</span> <span class="n">alpha</span><span class="o">.</span><span class="n">size</span>
<span class="o">==</span> <span class="mi">1</span><span class="p">),</span> <span class="s2">&quot;`alpha` must be passed as a float32 ndarray&quot;</span>
<span class="n">alpha</span> <span class="o">=</span> <span class="n">alpha</span> <span class="k">if</span> <span class="n">alpha</span> <span class="k">else</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span><!-- NOTE(review): same truthiness issue as the `if (alpha)` guard above; `alpha if alpha else ...` maps a zero-valued alpha to 1.0 -->
<span class="n">alpha</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;alpha&quot;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">.</span><span class="n">flatten</span><span class="p">(),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="k">if</span> <span class="n">strict_dtype</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">strict_dtype</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">)</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">strict_dtype</span>
<span class="k">if</span> <span class="p">(</span><span class="n">p_dtype</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">bfloat16</span><span class="p">]):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s2">&quot;strict_dtype must be float32, float16 or bfloat16 in low latency gemm plugin&quot;</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;need to use strict dtype in low latency gemm plugin fp8&quot;</span><span class="p">)</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">alpha</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
<span class="n">low_latency_gemm_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span>
<span class="s2">&quot;low_latency_gemm&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">mat2</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span>
<span class="n">low_latency_gemm_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creator</span><span class="p">,</span> <span class="s2">&quot;low_latency_gemm&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="SideStreamIDType">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.SideStreamIDType">[docs]</a><!-- Sphinx viewcode block: IntEnum naming the CUDA side streams (disable=0, moe=1); consumed by cuda_stream_sync below. Generated page; edit the Python source. -->
<span class="k">class</span><span class="w"> </span><span class="nc">SideStreamIDType</span><span class="p">(</span><span class="n">IntEnum</span><span class="p">):</span>
<span class="n">disable</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">moe</span> <span class="o">=</span> <span class="mi">1</span></div>
<div class="viewcode-block" id="low_latency_gemm_swiglu">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.low_latency_gemm_swiglu">[docs]</a><!-- Sphinx viewcode block: LowLatencyGemmSwiglu plugin wrapper (GEMM fused with SwiGLU); plugin field order is [type_id, scale_output, scale_d0, scale_d1]. Generated page; edit the Python source. -->
<span class="k">def</span><span class="w"> </span><span class="nf">low_latency_gemm_swiglu</span><span class="p">(</span><span class="nb">input</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">weight</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">scale_d0</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">scale_d1</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">scale_output</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add a matrix multiplication, followed by SwiGLU (`x * SiLU(gate)`) operation.</span>
<span class="sd"> The second SwiGLU operation takes the preceding tensor, splits it into two halves</span>
<span class="sd"> along the last dimension, applies SiLU to the second half and multiply the results. The</span>
<span class="sd"> behaviour is undefined if the last dimension is not even.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The first tensor (often called A).</span>
<span class="sd"> weight : Tensor</span>
<span class="sd"> The second tensor (often called B).</span>
<span class="sd"> scale_d0 : float</span>
<span class="sd"> The scale for dequantizing x, used for fp8</span>
<span class="sd"> scale_d1 : float</span>
<span class="sd"> The scale for dequantizing gate, used for fp8</span>
<span class="sd"> scale_output : float</span>
<span class="sd"> The scale for quantizing output, used for fp8</span>
<span class="sd"> Returns:</span>
<span class="sd"> The tensor produced by the inserted layer.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s1">&#39;LowLatencyGemmSwiglu&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">p_dtype</span> <span class="o">=</span> <span class="n">default_net</span><span class="p">()</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">low_latency_gemm_swiglu_plugin</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">p_dtype</span><span class="p">))],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pf_scale_d0</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;scale_d0&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">scale_d0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">pf_scale_d1</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;scale_d1&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">scale_d1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">pf_scale_output</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;scale_output&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">scale_output</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">FLOAT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">(</span>
<span class="p">[</span><span class="n">pf_type</span><span class="p">,</span> <span class="n">pf_scale_output</span><span class="p">,</span> <span class="n">pf_scale_d0</span><span class="p">,</span> <span class="n">pf_scale_d1</span><span class="p">])</span>
<span class="n">low_latency_gemm_swiglu_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span>
<span class="s2">&quot;low_latency_gemm_swiglu&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">weight</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span>
<span class="n">low_latency_gemm_swiglu_plug</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
<div class="viewcode-block" id="cuda_stream_sync">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.cuda_stream_sync">[docs]</a><!-- Sphinx viewcode block: CudaStream plugin wrapper; passes side_stream_id, num_inputs and the first input dtype as plugin fields and returns plugin output 0. Generated page; edit the Python source. -->
<span class="k">def</span><span class="w"> </span><span class="nf">cuda_stream_sync</span><span class="p">(</span><span class="n">input_list</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tensor</span><span class="p">],</span>
<span class="n">side_stream_id</span><span class="p">:</span> <span class="n">SideStreamIDType</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Wait for the side stream on the main stream.</span>
<span class="sd"> output = input_list[0]</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input_list : List[Tensor] (On GPU)</span>
<span class="sd"> The list of input tensors.</span>
<span class="sd"> side_stream_id : int (On CPU)</span>
<span class="sd"> The side stream ID.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_plugin_creator</span><span class="p">(</span>
<span class="s2">&quot;CudaStream&quot;</span><span class="p">,</span> <span class="s2">&quot;1&quot;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">p_side_stream_id</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;side_stream_id&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">side_stream_id</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">p_num_inputs</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;num_inputs&quot;</span><span class="p">,</span>
<span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">input_list</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pf_type</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span>
<span class="s2">&quot;type_id&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">input_list</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">dtype</span><span class="p">)],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">p_side_stream_id</span><span class="p">,</span> <span class="n">p_num_inputs</span><span class="p">,</span> <span class="n">pf_type</span><span class="p">])</span>
<span class="n">plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;cuda_stream&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span><span class="nb">input</span><span class="o">.</span><span class="n">trt_tensor</span> <span class="k">for</span> <span class="nb">input</span> <span class="ow">in</span> <span class="n">input_list</span><span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v2</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="n">plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creator</span><span class="p">,</span> <span class="s2">&quot;cuda_stream&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span>
<span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="cp_split_plugin">
<a class="viewcode-back" href="../../legacy/python-api/tensorrt_llm.functional.html#tensorrt_llm.functional.cp_split_plugin">[docs]</a><!-- Sphinx viewcode block: CpSplit IPluginV3 wrapper for context-parallel sequence splitting. Generated page; edit the Python source. -->
<span class="k">def</span><span class="w"> </span><span class="nf">cp_split_plugin</span><span class="p">(</span>
<span class="n">input_ids</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_request_types</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="c1"># for pad-free input mode</span>
<span class="n">cp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">cp_rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> Add an operation to perform splitting for context parallelism.</span>
<span class="sd"> This operation split the input_ids into cp_size chunks, and return the cp_rank-th</span>
<span class="sd"> chunk.</span>
<span class="sd"> When the seqlen % cp_size != 0, the chunk sizes of each rank would be</span>
<span class="sd"> [seqlen // cp_size, seqlen // cp_size, ..., seqlen - (seqlen // cp_size) * cp_size]</span>
<span class="sd"> It inserts a IPluginV3Layer.</span>
<span class="sd"> Parameters:</span>
<span class="sd"> input : Tensor</span>
<span class="sd"> The input tensor contains the indices to split.</span>
<span class="sd"> host_request_types: Tensor = None (On CPU)</span>
<span class="sd"> The tensor on the host that indicates if a request is in context or</span>
<span class="sd"> generation phase. Its shape is [batch_size]. See Inflight Batching</span>
<span class="sd"> in docs/gpt_attention.md,</span>
<span class="sd"> host_context_lengths: Tensor = None (On CPU)</span>
<span class="sd"> A host tensor that contains the lengths of the different inputs</span>
<span class="sd"> Returns:</span>
<span class="sd"> The output split tensor.</span>
<span class="sd"> The length of the output split tensor.</span>
<span class="sd"> The index for rebuilding the sequence</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">plg_creator</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">get_plugin_registry</span><span class="p">()</span><span class="o">.</span><span class="n">get_creator</span><span class="p">(</span>
<span class="s1">&#39;CpSplit&#39;</span><span class="p">,</span> <span class="s1">&#39;1&#39;</span><span class="p">,</span> <span class="n">TRT_LLM_PLUGIN_NAMESPACE</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">plg_creator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">cp_size</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;cp_size&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">cp_size</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">cp_rank</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginField</span><span class="p">(</span><span class="s2">&quot;cp_rank&quot;</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="nb">int</span><span class="p">(</span><span class="n">cp_rank</span><span class="p">)],</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span>
<span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldType</span><span class="o">.</span><span class="n">INT32</span><span class="p">)</span>
<span class="n">pfc</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">PluginFieldCollection</span><span class="p">([</span><span class="n">cp_size</span><span class="p">,</span> <span class="n">cp_rank</span><span class="p">])</span>
<span class="n">cp_split_plug</span> <span class="o">=</span> <span class="n">plg_creator</span><span class="o">.</span><span class="n">create_plugin</span><span class="p">(</span><span class="s2">&quot;cp_split&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">TensorRTPhase</span><span class="o">.</span><span class="n">BUILD</span><span class="p">)</span>
<span class="n">plug_inputs</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">input_ids</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span> <span class="n">host_request_types</span><span class="o">.</span><span class="n">trt_tensor</span><span class="p">,</span>
<span class="n">host_context_lengths</span><span class="o">.</span><span class="n">trt_tensor</span>
<span class="p">]</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">default_trtnet</span><span class="p">()</span><span class="o">.</span><span class="n">add_plugin_v3</span><span class="p">(</span><span class="n">plug_inputs</span><span class="p">,</span> <span class="p">[],</span> <span class="n">cp_split_plug</span><span class="p">)</span>
<span class="n">_add_plugin_info</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">plg_creator</span><span class="p">,</span> <span class="s2">&quot;cp_split&quot;</span><span class="p">,</span> <span class="n">pfc</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span><!-- NOTE(review): the docstring above lists three return values and the annotation says Tensor, but the code returns a 2-tuple of plugin outputs 0 and 2 (output 1 is not surfaced). Verify and reconcile in functional.py. -->
<span class="n">layer</span><span class="p">),</span> <span class="n">_create_tensor</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">get_output</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">layer</span><span class="p">)</span></div>
</pre></div>
</article>
<footer class="prev-next-footer d-print-none">
<div class="prev-next-area">
</div>
</footer>
</div>
<div class="bd-sidebar-secondary"></div>
</div>
<footer class="bd-footer-content">
</footer>
</main>
</div>
</div>
<!-- Scripts loaded after <body> so the DOM is not blocked -->
<script defer src="../../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
<script defer src="../../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
<footer class="bd-footer">
<div class="bd-footer__inner bd-page-width">
<div class="footer-items__start">
<div class="footer-item">
<a class="footer-brand logo" href="https://www.nvidia.com">
<img src="../../_static/nvidia-logo-horiz-rgb-1c-blk-for-screen.svg" class="logo__image only-light" alt="NVIDIA"/>
<img src="../../_static/nvidia-logo-horiz-rgb-1c-wht-for-screen.svg" class="logo__image only-dark" alt="NVIDIA"/>
</a></div>
<div class="footer-item">
<div class="footer-links">
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/privacy-policy/">Privacy Policy</a>
|
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/privacy-center/">Your Privacy Choices</a>
|
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/terms-of-service/">Terms of Service</a>
|
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/accessibility/">Accessibility</a>
|
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/company-policies/">Corporate Policies</a>
|
<a class="external" href="https://www.nvidia.com/en-us/product-security/">Product Security</a>
|
<a class="external" href="https://www.nvidia.com/en-us/contact/">Contact</a>
</div>
</div>
<div class="footer-item">
<p class="copyright">
          Copyright © 2025, NVIDIA.
<br/>
</p>
</div>
<div class="footer-item">
<div class="extra_footer">
<p>Last updated on January 04, 2026.</p>
<p>This page is generated by TensorRT-LLM commit <a href="https://github.com/NVIDIA/TensorRT-LLM/tree/a65b0d4">a65b0d4</a>.</p>
</div></div>
</div>
</div>
</footer>
</body>
</html>