TensorRT-LLMs/_modules/tensorrt_llm/builder.html
2025-09-04 03:19:11 +00:00

2069 lines
276 KiB
HTML
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!DOCTYPE html>
<html lang="en" data-content_root="../../" >
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>tensorrt_llm.builder &#8212; TensorRT-LLM</title>
<!--
Apply the color mode and theme saved in localStorage to the <html> element
before any of the stylesheets below are requested.
-->
<script data-cfasync="false">
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
</script>
<!--
This gives us a CSS class (.pst-js-only) that is hidden only when JavaScript is disabled.
-->
<noscript>
<style>
.pst-js-only { display: none !important; }
</style>
</noscript>
<!-- Loaded before other Sphinx assets -->
<link href="../../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link href="../../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link rel="stylesheet" type="text/css" href="../../_static/pygments.css?v=8f2a1f02" />
<link rel="stylesheet" type="text/css" href="../../_static/styles/nvidia-sphinx-theme.css?v=df3ac72c" />
<link rel="stylesheet" type="text/css" href="../../_static/copybutton.css?v=76b2166b" />
<link rel="stylesheet" type="text/css" href="../../_static/autodoc_pydantic.css" />
<link rel="stylesheet" type="text/css" href="../../_static/togglebutton.css?v=13237357" />
<link rel="stylesheet" type="text/css" href="../../_static/custom.css?v=95073da6" />
<!-- So that users can add custom icons -->
<script src="../../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
<!-- Pre-loaded scripts that we'll load fully later -->
<link rel="preload" as="script" href="../../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
<link rel="preload" as="script" href="../../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
<script src="../../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../../_static/doctools.js?v=9a2dae69"></script>
<script src="../../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../../_static/copybutton.js?v=65e89d2a"></script>
<!-- Toggle-button hint/print settings; presumably read by togglebutton.js below. -->
<script>let toggleHintShow = 'Click to show';</script>
<script>let toggleHintHide = 'Click to hide';</script>
<script>let toggleOpenOnPrint = 'true';</script>
<script src="../../_static/togglebutton.js?v=4a39c7ea"></script>
<!-- CSS selector of elements that presumably get a collapse/expand toggle from togglebutton.js. -->
<script>var togglebuttonSelector = '.toggle, .admonition.dropdown';</script>
<script>DOCUMENTATION_OPTIONS.pagename = '_modules/tensorrt_llm/builder';</script>
<script>
DOCUMENTATION_OPTIONS.theme_version = '0.16.1';
// NOTE(review): unlike the other asset URLs on this page, this path is not
// prefixed with "../../"; presumably the theme's version-switcher code resolves
// it itself rather than relative to this page -- confirm.
DOCUMENTATION_OPTIONS.theme_switcher_json_url = './_static/switcher.json';
DOCUMENTATION_OPTIONS.theme_switcher_version_match = '1.1.0rc3';
DOCUMENTATION_OPTIONS.show_version_warning_banner =
false;
</script>
<link rel="icon" href="../../_static/favicon.png"/>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<meta name="docsearch:language" content="en"/>
<meta name="docsearch:version" content="1.1.0rc3" />
</head>
<body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode="">
<!-- Skip link: lets keyboard users jump straight to the <main id="main-content"> element. -->
<div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div>
<!-- Empty helper element; NOTE(review): presumably observed by the theme JS to detect scrolling (e.g. to show the back-to-top button) - confirm. -->
<div id="pst-scroll-pixel-helper"></div>
<!-- "Back to top" button; its behavior and visibility are presumably managed by the theme JS. -->
<button type="button" class="btn rounded-pill" id="pst-back-to-top">
<i class="fa-solid fa-arrow-up"></i>Back to top</button>
<!-- Search dialog: submits the query as ?q=... via GET to the Sphinx search page (shortcut hint: Ctrl+K). -->
<dialog id="pst-search-dialog">
<form class="bd-search d-flex align-items-center"
action="../../search.html"
method="get">
<i class="fa-solid fa-magnifying-glass"></i>
<input type="search"
class="form-control"
name="q"
placeholder="Search the docs ..."
aria-label="Search the docs ..."
autocomplete="off"
autocorrect="off"
autocapitalize="off"
spellcheck="false"/>
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span>
</form>
</dialog>
<!-- Version-warning banner slot; starts hidden (d-none). The banner is disabled for this page (show_version_warning_banner = false in the head script). -->
<div class="pst-async-banner-revealer d-none">
<aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside>
</div>
<header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none">
<div class="bd-header__inner bd-page-width">
<!-- Toggles the primary (left) site-navigation sidebar. -->
<button class="pst-navbar-icon sidebar-toggle primary-toggle" aria-label="Site navigation">
<span class="fa-solid fa-bars"></span>
</button>
<!-- Brand link: light-mode logo, dark-mode logo (JS-only), and text title. -->
<div class="col-lg-3 navbar-header-items__start">
<div class="navbar-item">
<a class="navbar-brand logo" href="../../index.html">
<img src="../../_static/nvidia-logo-horiz-rgb-blk-for-screen.svg" class="logo__image only-light" alt="TensorRT-LLM - Home"/>
<img src="../../_static/nvidia-logo-horiz-rgb-wht-for-screen.svg" class="logo__image only-dark pst-js-only" alt="TensorRT-LLM - Home"/>
<p class="title logo__title">TensorRT-LLM</p>
</a></div>
</div>
<div class="col-lg-9 navbar-header-items">
<!-- Version switcher dropdown (JS-only); its list is filled in by JavaScript at page load. -->
<div class="me-auto navbar-header-items__center">
<div class="navbar-item">
<div class="version-switcher__container dropdown pst-js-only">
<button id="pst-version-switcher-button-2"
type="button"
class="version-switcher__button btn btn-sm dropdown-toggle"
data-bs-toggle="dropdown"
aria-haspopup="listbox"
aria-controls="pst-version-switcher-list-2"
aria-label="Version switcher list"
>
Choose version <!-- this text may get changed later by javascript -->
<span class="caret"></span>
</button>
<div id="pst-version-switcher-list-2"
class="version-switcher__menu dropdown-menu list-group-flush py-0"
role="listbox" aria-labelledby="pst-version-switcher-button-2">
<!-- dropdown will be populated by javascript on page load -->
</div>
</div></div>
</div>
<!-- Right-hand navbar items: search button (JS-only) and color-mode switch. -->
<div class="navbar-header-items__end">
<div class="navbar-item navbar-persistent--container">
<button class="btn search-button-field search-button__button pst-js-only" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="fa-solid fa-magnifying-glass"></i>
<span class="search-button__default-text">Search</span>
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
</button>
</div>
<!-- Color-mode switch with one icon per mode: light, dark, auto (system settings). -->
<div class="navbar-item">
<button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button pst-js-only" aria-label="Color mode" data-bs-title="Color mode" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light" title="Light"></i>
<i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark" title="Dark"></i>
<i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto" title="System Settings"></i>
</button></div>
</div>
</div>
<!-- Mobile copy of the search button (same markup as the desktop one above). -->
<div class="navbar-persistent--mobile">
<button class="btn search-button-field search-button__button pst-js-only" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="fa-solid fa-magnifying-glass"></i>
<span class="search-button__default-text">Search</span>
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
</button>
</div>
</div>
</header>
<div class="bd-container">
<div class="bd-container__inner bd-page-width">
<dialog id="pst-primary-sidebar-modal"></dialog>
<div id="pst-primary-sidebar" class="bd-sidebar-primary bd-sidebar">
<!-- Logo repeated inside the primary sidebar (same markup as the navbar brand link). -->
<a class="navbar-brand logo" href="../../index.html">
<img src="../../_static/nvidia-logo-horiz-rgb-blk-for-screen.svg" class="logo__image only-light" alt="TensorRT-LLM - Home"/>
<img src="../../_static/nvidia-logo-horiz-rgb-wht-for-screen.svg" class="logo__image only-dark pst-js-only" alt="TensorRT-LLM - Home"/>
<p class="title logo__title">TensorRT-LLM</p>
</a>
<!--
Sidebar copies of the version switcher (ids suffixed -3 to stay unique) and the
color-mode switch. NOTE(review): presumably the theme CSS shows these instead of
the navbar copies on narrow screens - confirm.
-->
<div class="sidebar-header-items sidebar-primary__section">
<div class="sidebar-header-items__center">
<div class="navbar-item">
<div class="version-switcher__container dropdown pst-js-only">
<button id="pst-version-switcher-button-3"
type="button"
class="version-switcher__button btn btn-sm dropdown-toggle"
data-bs-toggle="dropdown"
aria-haspopup="listbox"
aria-controls="pst-version-switcher-list-3"
aria-label="Version switcher list"
>
Choose version <!-- this text may get changed later by javascript -->
<span class="caret"></span>
</button>
<div id="pst-version-switcher-list-3"
class="version-switcher__menu dropdown-menu list-group-flush py-0"
role="listbox" aria-labelledby="pst-version-switcher-button-3">
<!-- dropdown will be populated by javascript on page load -->
</div>
</div></div>
</div>
<div class="sidebar-header-items__end">
<div class="navbar-item">
<button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button pst-js-only" aria-label="Color mode" data-bs-title="Color mode" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light" title="Light"></i>
<i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark" title="Dark"></i>
<i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto" title="System Settings"></i>
</button></div>
</div>
</div>
<div class="sidebar-primary-items__start sidebar-primary__section">
<div class="sidebar-primary-item">
<nav class="bd-docs-nav bd-links"
aria-label="Table of Contents">
<p class="bd-links__title" role="heading" aria-level="1">Table of Contents</p>
<div class="bd-toc-item navbar-nav"><p aria-level="2" class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../overview.html">Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../quick-start-guide.html">Quick Start Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../key-features.html">Key Features</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../torch.html">PyTorch Backend</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../release-notes.html">Release Notes</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Installation</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../installation/containers.html">Pre-built release container images on NGC</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../installation/linux.html">Installing on Linux via <code class="docutils literal notranslate"><span class="pre">pip</span></code></a></li>
<li class="toctree-l1"><a class="reference internal" href="../../installation/build-from-source-linux.html">Building from Source Code on Linux</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Deployment Guide</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../deployment-guide/quick-start-recipe-for-llama4-scout-on-trtllm.html">Quick Start Recipe for Llama4 Scout 17B on TensorRT-LLM - Blackwell &amp; Hopper Hardware</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.html">Quick Start Recipe for DeepSeek R1 on TensorRT-LLM - Blackwell &amp; Hopper Hardware</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../deployment-guide/quick-start-recipe-for-llama3.3-70b-on-trtllm.html">Quick Start Recipe for Llama3.3 70B on TensorRT-LLM - Blackwell &amp; Hopper Hardware</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../deployment-guide/quick-start-recipe-for-gpt-oss-on-trtllm.html">Quick Start Recipe for GPT-OSS on TensorRT-LLM - Blackwell Hardware</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">LLM API</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../llm-api/index.html">LLM API Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../llm-api/reference.html">API Reference</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Examples</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1 has-children"><a class="reference internal" href="../../examples/index.html">LLM Examples Introduction</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul class="simple">
</ul>
</details></li>
<li class="toctree-l1"><a class="reference internal" href="../../examples/customization.html">LLM Common Customizations</a></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../../examples/llm_api_examples.html">LLM Examples</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_inference.html">Generate text</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_inference_async.html">Generate text asynchronously</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_inference_async_streaming.html">Generate text in streaming</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_inference_distributed.html">Distributed LLM Generation</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_guided_decoding.html">Generate text with guided decoding</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_logits_processor.html">Control generated text using logits processor</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_multilora.html">Generate text with multiple LoRA adapters</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_speculative_decoding.html">Speculative Decoding</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_kv_cache_connector.html">KV Cache Connector</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_runtime.html">Runtime Configuration Examples</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_sampling.html">Sampling Techniques Showcase</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_mgmn_llm_distributed.html">Run LLM-API with pytorch backend on Slurm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_mgmn_trtllm_bench.html">Run trtllm-bench with pytorch backend on Slurm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_mgmn_trtllm_serve.html">Run trtllm-serve with pytorch backend on Slurm</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../../examples/trtllm_serve_examples.html">Online Serving Examples</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../../examples/curl_chat_client.html">Curl Chat Client</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/curl_chat_client_for_multimodal.html">Curl Chat Client For Multimodal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/curl_completion_client.html">Curl Completion Client</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/deepseek_r1_reasoning_parser.html">Deepseek R1 Reasoning Parser</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/genai_perf_client.html">Genai Perf Client</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/genai_perf_client_for_multimodal.html">Genai Perf Client For Multimodal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/openai_chat_client.html">OpenAI Chat Client</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/openai_chat_client_for_multimodal.html">OpenAI Chat Client for Multimodal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/openai_completion_client.html">OpenAI Completion Client</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/openai_completion_client_for_lora.html">Openai Completion Client For Lora</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/openai_completion_client_json_schema.html">OpenAI Completion Client with JSON Schema</a></li>
</ul>
</details></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Model Definition API</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.layers.html">Layers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.functional.html">Functionals</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.models.html">Models</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.plugin.html">Plugin</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.quantization.html">Quantization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../python-api/tensorrt_llm.runtime.html">Runtime</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">C++ API</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../_cpp_gen/executor.html">Executor</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../_cpp_gen/runtime.html">Runtime</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Command-Line Reference</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../commands/trtllm-bench.html">trtllm-bench</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../commands/trtllm-build.html">trtllm-build</a></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../../commands/trtllm-serve/index.html">trtllm-serve</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../../commands/trtllm-serve/trtllm-serve.html">trtllm-serve</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../commands/trtllm-serve/run-benchmark-with-trtllm-serve.html">Run benchmarking with <code class="docutils literal notranslate"><span class="pre">trtllm-serve</span></code></a></li>
</ul>
</details></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Architecture</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../architecture/overview.html">TensorRT-LLM Architecture</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/core-concepts.html">Model Definition</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/checkpoint.html">TensorRT-LLM Checkpoint</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/workflow.html">TensorRT-LLM Build Workflow</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../architecture/add-model.html">Adding a Model</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Advanced</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../advanced/gpt-attention.html">Multi-Head, Multi-Query, and Group-Query Attention</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/gpt-runtime.html">C++ GPT Runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/executor.html">Executor API</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/graph-rewriting.html">Graph Rewriting Module</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/lora.html">Run gpt-2b + LoRA using Executor / cpp runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/kv-cache-management.html">KV Cache Management: Pools, Blocks, and Events</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/speculative-decoding.html">Speculative Sampling</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../advanced/disaggregated-service.html">Disaggregated-Service (Prototype)</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Performance</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../performance/perf-overview.html">Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../performance/perf-benchmarking.html">Benchmarking</a></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../../performance/performance-tuning-guide/index.html">Performance Tuning Guide</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../../performance/performance-tuning-guide/benchmarking-default-performance.html">Benchmarking Default Performance</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../performance/performance-tuning-guide/useful-build-time-flags.html">Useful Build-Time Flags</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../performance/performance-tuning-guide/tuning-max-batch-size-and-max-num-tokens.html">Tuning Max Batch Size and Max Num Tokens</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../performance/performance-tuning-guide/deciding-model-sharding-strategy.html">Deciding Model Sharding Strategy</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../performance/performance-tuning-guide/fp8-quantization.html">FP8 Quantization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../performance/performance-tuning-guide/useful-runtime-flags.html">Useful Runtime Options</a></li>
</ul>
</details></li>
<li class="toctree-l1"><a class="reference internal" href="../../performance/perf-analysis.html">Performance Analysis</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Reference</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../reference/troubleshooting.html">Troubleshooting</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/support-matrix.html">Support Matrix</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/precision.html">Numerical Precision</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/memory.html">Memory Usage of TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/ci-overview.html">Continuous Integration Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../reference/dev-containers.html">Using Dev Containers</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Blogs</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/Falcon180B-H200.html">Falcon-180B on a single H200 GPU with INT4 AWQ, and 6.7x faster Llama-70B over A100</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/quantization-in-TRT-LLM.html">Speed up inference with SOTA quantization techniques in TRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/XQA-kernel.html">New XQA-kernel provides 2.4x more Llama-70B throughput within the same latency budget</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog10_ADP_Balance_Strategy.html">ADP Balance Strategy</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog1_Pushing_Latency_Boundaries_Optimizing_DeepSeek-R1_Performance_on_NVIDIA_B200_GPUs.html">Pushing Latency Boundaries: Optimizing DeepSeek-R1 Performance on NVIDIA B200 GPUs</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog2_DeepSeek_R1_MTP_Implementation_and_Optimization.html">DeepSeek R1 MTP Implementation and Optimization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog3_Optimizing_DeepSeek_R1_Throughput_on_NVIDIA_Blackwell_GPUs.html">Optimizing DeepSeek R1 Throughput on NVIDIA Blackwell GPUs: A Deep Dive for Developers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog4_Scaling_Expert_Parallelism_in_TensorRT-LLM.html">Scaling Expert Parallelism in TensorRT-LLM (Part 1: Design and Implementation of Large-scale EP)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog5_Disaggregated_Serving_in_TensorRT-LLM.html">Disaggregated Serving in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog6_Llama4_maverick_eagle_guide.html">How to launch Llama4 Maverick + Eagle3 TensorRT-LLM server</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog7_NGram_performance_Analysis_And_Auto_Enablement.html">N-Gram Speculative Decoding in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog8_Scaling_Expert_Parallelism_in_TensorRT-LLM_part2.html">Scaling Expert Parallelism in TensorRT-LLM (Part 2: Performance Status and Optimization)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.html">Running a High Performance GPT-OSS-120B Inference Server with TensorRT-LLM</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Use TensorRT Engine</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../legacy/tensorrt_quickstart.html">LLM API with TensorRT Engine</a></li>
</ul>
</div>
</nav></div>
</div>
<div class="sidebar-primary-items__end sidebar-primary__section">
</div>
</div>
<main id="main-content" class="bd-main" role="main">
<div class="bd-content">
<div class="bd-article-container">
<!-- Article header: breadcrumb trail Home > Module code > tensorrt_llm.builder (hidden when printing). -->
<div class="bd-header-article d-print-none">
<div class="header-article-items header-article__inner">
<div class="header-article-items__start">
<div class="header-article-item">
<nav aria-label="Breadcrumb" class="d-print-none">
<ul class="bd-breadcrumbs">
<!-- Icon-only home link; its accessible name comes from aria-label="Home". -->
<li class="breadcrumb-item breadcrumb-home">
<a href="../../index.html" class="nav-link" aria-label="Home">
<i class="fa-solid fa-home"></i>
</a>
</li>
<li class="breadcrumb-item"><a href="../index.html" class="nav-link">Module code</a></li>
<!-- Current page: marked with aria-current="page" and not linked. -->
<li class="breadcrumb-item active" aria-current="page"><span class="ellipsis">tensorrt_llm.builder</span></li>
</ul>
</nav>
</div>
</div>
</div>
</div>
<div id="searchbox"></div>
<article class="bd-article">
<h1>Source code for tensorrt_llm.builder</h1><div class="highlight"><pre>
<span></span><span class="c1"># SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &amp; AFFILIATES. All rights reserved.</span>
<span class="c1"># SPDX-License-Identifier: Apache-2.0</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span>
<span class="c1"># you may not use this file except in compliance with the License.</span>
<span class="c1"># You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">copy</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">dataclasses</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">json</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">math</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">os</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">shutil</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">time</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">dataclasses</span><span class="w"> </span><span class="kn">import</span> <span class="n">dataclass</span><span class="p">,</span> <span class="n">field</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">functools</span><span class="w"> </span><span class="kn">import</span> <span class="n">cache</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pathlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">Path</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">typing</span><span class="w"> </span><span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">tensorrt</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">trt</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">._common</span><span class="w"> </span><span class="kn">import</span> <span class="n">_is_building</span><span class="p">,</span> <span class="n">check_max_num_tokens</span><span class="p">,</span> <span class="n">serialize_engine</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">._utils</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span><span class="n">get_sm_version</span><span class="p">,</span> <span class="n">np_bfloat16</span><span class="p">,</span> <span class="n">np_float8</span><span class="p">,</span> <span class="n">str_dtype_to_trt</span><span class="p">,</span>
<span class="n">to_json_file</span><span class="p">,</span> <span class="n">trt_gte</span><span class="p">)</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.auto_parallel</span><span class="w"> </span><span class="kn">import</span> <span class="n">auto_parallel</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.auto_parallel.config</span><span class="w"> </span><span class="kn">import</span> <span class="n">AutoParallelConfig</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.bindings</span><span class="w"> </span><span class="kn">import</span> <span class="n">KVCacheType</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.functional</span><span class="w"> </span><span class="kn">import</span> <span class="n">PositionEmbeddingType</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.graph_rewriting</span><span class="w"> </span><span class="kn">import</span> <span class="n">optimize</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.logger</span><span class="w"> </span><span class="kn">import</span> <span class="n">logger</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.lora_helper</span><span class="w"> </span><span class="kn">import</span> <span class="n">LoraConfig</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.models</span><span class="w"> </span><span class="kn">import</span> <span class="n">PretrainedConfig</span><span class="p">,</span> <span class="n">PretrainedModel</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.models.modeling_utils</span><span class="w"> </span><span class="kn">import</span> <span class="n">SpeculativeDecodingMode</span><span class="p">,</span> <span class="n">optimize_model</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.network</span><span class="w"> </span><span class="kn">import</span> <span class="n">Network</span><span class="p">,</span> <span class="n">net_guard</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.plugin</span><span class="w"> </span><span class="kn">import</span> <span class="n">PluginConfig</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.quantization</span><span class="w"> </span><span class="kn">import</span> <span class="n">QuantAlgo</span><span class="p">,</span> <span class="n">QuantMode</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.version</span><span class="w"> </span><span class="kn">import</span> <span class="n">__version__</span>
<!--
ConfigEncoder: json.JSONEncoder subclass for values the stdlib encoder cannot
serialize natively. json calls default() only for such values:
  KVCacheType members: returns the text after the last '.' of str(obj),
    e.g. 'KVCacheType.PAGED' becomes 'PAGED'.
  objects that have a model_dump attribute (Pydantic-style models): returns
    obj.model_dump(mode='json').
  anything else: delegates to json.JSONEncoder.default, which raises TypeError.
-->
<span class="k">class</span><span class="w"> </span><span class="nc">ConfigEncoder</span><span class="p">(</span><span class="n">json</span><span class="o">.</span><span class="n">JSONEncoder</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">default</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">obj</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">KVCacheType</span><span class="p">):</span>
<span class="c1"># For KVCacheType, convert it to string by split of &#39;KVCacheType.PAGED&#39;.</span>
<span class="k">return</span> <span class="n">obj</span><span class="o">.</span><span class="fm">__str__</span><span class="p">()</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">&#39;.&#39;</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="k">elif</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="s1">&#39;model_dump&#39;</span><span class="p">):</span>
<span class="c1"># Handle Pydantic models (including DecodingBaseConfig and subclasses)</span>
<span class="k">return</span> <span class="n">obj</span><span class="o">.</span><span class="n">model_dump</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s1">&#39;json&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">default</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span>
<!--
BuilderConfig: holder pairing a TensorRT IBuilderConfig with arbitrary build
attributes.
  __init__ intentionally does nothing. Instances are filled in by
    _init(trt_builder_config, **kwargs), which stores the TRT object as
    _trt_builder_config, sets every kwarg as an attribute, and returns self so
    it can be chained after the constructor.
  trt_builder_config: read-only property exposing the wrapped IBuilderConfig.
  to_dict(): copies every instance attribute except _trt_builder_config,
    plugin_config and auto_parallel_config into config['builder_config'].
    If a plugin_config attribute exists, it must be a PluginConfig, and its
    to_dict() result is stored under config['plugin_config'].
NOTE(review): the to_dict docstring also lists an "auto_parallel_config" key,
  but the body only excludes that attribute and never emits it. This page is
  generated, so reconcile docstring and body in the module source.
NOTE(review): the PluginConfig type check is an assert, so it is skipped when
  Python runs with -O.
-->
<span class="k">class</span><span class="w"> </span><span class="nc">BuilderConfig</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="c1"># intentionally use **kwargs, user should never call this ctor directly,</span>
<span class="c1"># use Builder.create_builder_config() instead</span>
<span class="k">pass</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_init</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">trt_builder_config</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_trt_builder_config</span> <span class="o">=</span> <span class="n">trt_builder_config</span>
<span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">trt_builder_config</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">trt</span><span class="o">.</span><span class="n">IBuilderConfig</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_trt_builder_config</span>
<span class="k">def</span><span class="w"> </span><span class="nf">to_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;return a dict with keys</span>
<span class="sd"> {</span>
<span class="sd"> &quot;builder_config&quot;: {</span>
<span class="sd"> # all key values set by the _init function</span>
<span class="sd"> },</span>
<span class="sd"> &quot;plugin_config&quot;: {</span>
<span class="sd"> # the network plugin_config (if any) attached to this BuilderConfig object</span>
<span class="sd"> # inside the Builder.build_engine</span>
<span class="sd"> },</span>
<span class="sd"> &quot;auto_parallel_config&quot;: {</span>
<span class="sd"> # the network auto_parallel_config (if any) attached to this BuilderConfig object</span>
<span class="sd"> # inside the Builder.build_engine</span>
<span class="sd"> }</span>
<span class="sd"> }</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">config</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;builder_config&#39;</span><span class="p">:</span> <span class="p">{}}</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="k">if</span> <span class="n">k</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span>
<span class="s1">&#39;_trt_builder_config&#39;</span><span class="p">,</span> <span class="s1">&#39;plugin_config&#39;</span><span class="p">,</span>
<span class="s1">&#39;auto_parallel_config&#39;</span>
<span class="p">]:</span>
<span class="n">config</span><span class="p">[</span><span class="s1">&#39;builder_config&#39;</span><span class="p">][</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="fm">__getattribute__</span><span class="p">(</span><span class="n">k</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s1">&#39;plugin_config&#39;</span><span class="p">):</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">,</span> <span class="n">PluginConfig</span><span class="p">),</span> \
<span class="sa">f</span><span class="s2">&quot;Found unexpected plugin_config object with type: </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">config</span><span class="p">[</span><span class="s1">&#39;plugin_config&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">to_dict</span><span class="p">()</span>
<span class="k">return</span> <span class="n">config</span>
<span class="k">class</span><span class="w"> </span><span class="nc">Builder</span><span class="p">():</span>
<!--
Builder._ALLOWED_PRECISIONS: precisions accepted by create_builder_config when
a non strongly typed build is requested. Both the string names and the
matching trt.DataType members are listed because the precision argument is
typed Union[str, trt.DataType].
-->
<span class="n">_ALLOWED_PRECISIONS</span> <span class="o">=</span> <span class="p">[</span>
<span class="s1">&#39;float32&#39;</span><span class="p">,</span> <span class="s1">&#39;float16&#39;</span><span class="p">,</span> <span class="s1">&#39;bfloat16&#39;</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">HALF</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">FLOAT</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">BF16</span>
<span class="p">]</span>
<!--
Builder.__init__: creates the wrapped trt.Builder using the project logger's
TensorRT logger, and starts with strongly_typed = True.
create_builder_config may later clear strongly_typed; create_network reads it
to choose the network creation flags.
-->
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_trt_builder</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Builder</span><span class="p">(</span><span class="n">logger</span><span class="o">.</span><span class="n">trt_logger</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">strongly_typed</span> <span class="o">=</span> <span class="kc">True</span>
<!--
Builder.trt_builder: read-only property returning the wrapped trt.Builder
created in __init__.
-->
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">trt_builder</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">trt</span><span class="o">.</span><span class="n">Builder</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_trt_builder</span>
<!--
Builder.create_network: creates a TensorRT network definition and wraps it in
Network via Network()._init(...).
Creation flags:
  EXPLICIT_BATCH is added only if the installed TensorRT still defines it in
  NetworkDefinitionCreationFlag. The inline comment notes it is deprecated in
  TRT 10. Otherwise the base flag word is 0.
  STRONGLY_TYPED is added on top when self.strongly_typed is set.
NOTE(review): self.strongly_typed may be cleared by create_builder_config, so
  the result presumably depends on calling that method first when a non
  strongly typed build is wanted. Confirm the intended call order in callers.
-->
<span class="k">def</span><span class="w"> </span><span class="nf">create_network</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Network</span><span class="p">:</span>
<span class="n">explicit_batch_flag</span> <span class="o">=</span> <span class="mi">0</span>
<span class="c1"># Explicit batch flag will be deprecated in TRT 10</span>
<span class="k">if</span> <span class="s2">&quot;EXPLICIT_BATCH&quot;</span> <span class="ow">in</span> <span class="n">trt</span><span class="o">.</span><span class="n">NetworkDefinitionCreationFlag</span><span class="o">.</span><span class="n">__members__</span><span class="o">.</span><span class="n">keys</span><span class="p">(</span>
<span class="p">):</span>
<span class="n">explicit_batch_flag</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">&lt;&lt;</span> <span class="nb">int</span><span class="p">(</span>
<span class="n">trt</span><span class="o">.</span><span class="n">NetworkDefinitionCreationFlag</span><span class="o">.</span><span class="n">EXPLICIT_BATCH</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
<span class="k">return</span> <span class="n">Network</span><span class="p">()</span><span class="o">.</span><span class="n">_init</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trt_builder</span><span class="o">.</span><span class="n">create_network</span><span class="p">(</span>
<span class="n">explicit_batch_flag</span>
<span class="o">|</span> <span class="p">(</span><span class="mi">1</span> <span class="o">&lt;&lt;</span> <span class="nb">int</span><span class="p">(</span>
<span class="n">trt</span><span class="o">.</span><span class="n">NetworkDefinitionCreationFlag</span><span class="o">.</span><span class="n">STRONGLY_TYPED</span><span class="p">))))</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">Network</span><span class="p">()</span><span class="o">.</span><span class="n">_init</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trt_builder</span><span class="o">.</span><span class="n">create_network</span><span class="p">(</span><span class="n">explicit_batch_flag</span><span class="p">))</span>
<!--
Builder.create_builder_config: builds a trt.IBuilderConfig and returns it
wrapped in a BuilderConfig. The wrapper carries these attributes:
  precision, tensor_parallel, use_refit, int8 and force_num_profiles;
  strongly_typed (the effective self.strongly_typed);
  use_strip_plan;
  every extra kwarg (e.g. quant_mode, weight_sparsity, monitor_memory).

Steps, in order:
  1. self.strongly_typed is AND-ed with the strongly_typed argument. Once a
     call passes False, it stays False for this Builder.
  2. quant_mode is read from kwargs (default QuantMode(0)). An unsupported
     precision is logged as an error when strongly_typed=False.
  3. WEIGHT_STREAMING is set if weight_streaming is true.
  4. Non strongly typed networks only:
       FP16 or BF16 is set from precision;
       FP8 is set if quant_mode has FP8 QDQ or FP8 KV cache (fp8 is only
       defined in this branch);
       OBEY_PRECISION_CONSTRAINTS is added with those flags when
       precision_constraints == 'obey'.
     INT8 is set if int8 is true.
  5. REFIT is set if use_refit is true. REFIT_INDIVIDUAL and STRIP_PLAN are
     set if use_strip_plan is true.
  6. profiling_verbosity: "detailed" gives DETAILED, "none" gives NONE, and
     any other value gives LAYER_NAMES_ONLY.
  7. Timing cache:
       an ITimingCache is used as-is;
       an existing str/Path file is read and passed to
       config.create_timing_cache;
       otherwise a warning is logged and an empty cache is created.
     The cache is attached with ignore_mismatch=False.
  8. SPARSE_WEIGHTS is set if kwargs["weight_sparsity"] is truthy.
     MONITOR_MEMORY is set if kwargs["monitor_memory"] is truthy and
     trt_gte(10, 6) holds.

This page is generated; fix the items below in the module source.
NOTE(review): an unsupported precision only logs an error and execution
  continues. The docstring says None is returned on failure, but no path
  returns None.
NOTE(review): the precision check tests the strongly_typed argument, not the
  effective self.strongly_typed. The check is therefore skipped when this call
  passes True after an earlier call already cleared self.strongly_typed.
NOTE(review): the docstring documents "refit", but the parameter is use_refit.
  strongly_typed, force_num_profiles, profiling_verbosity, use_strip_plan,
  weight_streaming and precision_constraints are undocumented.
NOTE(review): precision_constraints is not forwarded to the returned
  BuilderConfig, so to_dict() will not record it.
NOTE(review): a str/Path timing_cache that does not exist yet is reported as
  "Invalid timing cache", although a fresh cache is then used.
NOTE(review): leading whitespace is not visible in this copy. Whether INT8
  sits inside the non strongly typed branch, and whether REFIT_INDIVIDUAL is
  nested under use_refit, cannot be confirmed here; check the module source.
-->
<span class="k">def</span><span class="w"> </span><span class="nf">create_builder_config</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">precision</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">],</span>
<span class="n">timing_cache</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">ITimingCache</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tensor_parallel</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">use_refit</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">int8</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">strongly_typed</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">force_num_profiles</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">profiling_verbosity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;layer_names_only&quot;</span><span class="p">,</span>
<span class="n">use_strip_plan</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">weight_streaming</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">precision_constraints</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&quot;obey&quot;</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">BuilderConfig</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39; @brief Create a builder config with given precisions and timing cache</span>
<span class="sd"> @param precision: one of allowed precisions, defined in Builder._ALLOWED_PRECISIONS</span>
<span class="sd"> @param timing_cache: a timing cache object or a path to a timing cache file</span>
<span class="sd"> @param tensor_parallel: number of GPUs used for tensor parallel</span>
<span class="sd"> @param kwargs: any other arguments users would like to attach to the config object as attributes</span>
<span class="sd"> @param refit: set to accelerate multi-gpu building, build engine for 1 gpu and refit for the others</span>
<span class="sd"> @param int8: whether to build with int8 enabled or not. Can&#39;t be used together with refit option</span>
<span class="sd"> @return: A BuilderConfig object, return None if failed</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">strongly_typed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">strongly_typed</span> <span class="ow">and</span> <span class="n">strongly_typed</span><!-- sticky: once False, never reset to True here -->
<span class="n">quant_mode</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;quant_mode&quot;</span><span class="p">,</span> <span class="n">QuantMode</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">strongly_typed</span> <span class="ow">and</span> <span class="n">precision</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ALLOWED_PRECISIONS</span><span class="p">:</span><!-- NOTE(review): logs only; execution continues -->
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;precision should be one of </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_ALLOWED_PRECISIONS</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">config</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_builder</span><span class="o">.</span><span class="n">create_builder_config</span><span class="p">()</span>
<span class="k">if</span> <span class="n">weight_streaming</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">WEIGHT_STREAMING</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
<span class="n">fp8</span> <span class="o">=</span> <span class="n">quant_mode</span><span class="o">.</span><span class="n">has_fp8_qdq</span><span class="p">()</span> <span class="ow">or</span> <span class="n">quant_mode</span><span class="o">.</span><span class="n">has_fp8_kv_cache</span><span class="p">()</span>
<span class="k">if</span> <span class="n">precision</span> <span class="o">==</span> <span class="s1">&#39;float16&#39;</span> <span class="ow">or</span> <span class="n">precision</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">HALF</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">FP16</span><span class="p">)</span>
<span class="k">if</span> <span class="n">precision_constraints</span> <span class="o">==</span> <span class="s1">&#39;obey&#39;</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">OBEY_PRECISION_CONSTRAINTS</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">precision</span> <span class="o">==</span> <span class="s1">&#39;bfloat16&#39;</span> <span class="ow">or</span> <span class="n">precision</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">BF16</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">BF16</span><span class="p">)</span>
<span class="k">if</span> <span class="n">precision_constraints</span> <span class="o">==</span> <span class="s1">&#39;obey&#39;</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">OBEY_PRECISION_CONSTRAINTS</span><span class="p">)</span>
<span class="k">if</span> <span class="n">int8</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="k">if</span> <span class="n">fp8</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">FP8</span><span class="p">)</span>
<span class="k">if</span> <span class="n">precision_constraints</span> <span class="o">==</span> <span class="s1">&#39;obey&#39;</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">OBEY_PRECISION_CONSTRAINTS</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_refit</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">REFIT</span><span class="p">)</span>
<span class="c1"># Use fine-grained refit when strip plan is enabled in TRT10.2+.</span>
<span class="k">if</span> <span class="n">use_strip_plan</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">REFIT_INDIVIDUAL</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_strip_plan</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">STRIP_PLAN</span><span class="p">)</span>
<span class="c1"># Set TRT Engine profiling verbosity</span>
<span class="k">if</span> <span class="n">profiling_verbosity</span> <span class="o">==</span> <span class="s2">&quot;detailed&quot;</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">profiling_verbosity</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ProfilingVerbosity</span><span class="o">.</span><span class="n">DETAILED</span>
<span class="k">elif</span> <span class="n">profiling_verbosity</span> <span class="o">==</span> <span class="s2">&quot;none&quot;</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">profiling_verbosity</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ProfilingVerbosity</span><span class="o">.</span><span class="n">NONE</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">profiling_verbosity</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ProfilingVerbosity</span><span class="o">.</span><span class="n">LAYER_NAMES_ONLY</span>
<span class="c1"># set timing cache</span>
<span class="n">cache</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">timing_cache</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># use given cache</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">timing_cache</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">ITimingCache</span><span class="p">):</span>
<span class="n">cache</span> <span class="o">=</span> <span class="n">timing_cache</span>
<span class="c1"># read cache from file</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">timing_cache</span><span class="p">,</span>
<span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">))</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">timing_cache</span><span class="p">):</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">timing_cache</span><span class="p">,</span> <span class="s2">&quot;rb&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">cache</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">create_timing_cache</span><span class="p">(</span><span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">())</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;Invalid timing cache, using freshly created one&quot;</span><span class="p">)</span><!-- NOTE(review): also hit for a path that simply does not exist yet -->
<span class="k">if</span> <span class="n">cache</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">cache</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">create_timing_cache</span><span class="p">(</span><span class="sa">b</span><span class="s2">&quot;&quot;</span><span class="p">)</span>
<span class="c1"># When user does not given any existing cache, internally always created one</span>
<span class="c1"># so the cache should never None here</span>
<span class="k">assert</span> <span class="n">cache</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">cache</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">ITimingCache</span><span class="p">)</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_timing_cache</span><span class="p">(</span><span class="n">cache</span><span class="p">,</span> <span class="n">ignore_mismatch</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="c1"># set weight sparsity</span>
<span class="n">weight_sparsity</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;weight_sparsity&quot;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<span class="k">if</span> <span class="n">weight_sparsity</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">SPARSE_WEIGHTS</span><span class="p">)</span>
<span class="c1"># TODO: remove this constraint after trt 10.6 is integrated</span>
<span class="k">if</span> <span class="n">trt_gte</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">6</span><span class="p">):</span>
<span class="c1"># set monitor memory</span>
<span class="n">monitor_memory</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;monitor_memory&quot;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<span class="k">if</span> <span class="n">monitor_memory</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">MONITOR_MEMORY</span><span class="p">)</span>
<span class="k">return</span> <span class="n">BuilderConfig</span><span class="p">()</span><span class="o">.</span><span class="n">_init</span><span class="p">(</span><span class="n">config</span><span class="p">,</span>
<span class="n">precision</span><span class="o">=</span><span class="n">precision</span><span class="p">,</span>
<span class="n">tensor_parallel</span><span class="o">=</span><span class="n">tensor_parallel</span><span class="p">,</span>
<span class="n">use_refit</span><span class="o">=</span><span class="n">use_refit</span><span class="p">,</span>
<span class="n">int8</span><span class="o">=</span><span class="n">int8</span><span class="p">,</span>
<span class="n">force_num_profiles</span><span class="o">=</span><span class="n">force_num_profiles</span><span class="p">,</span>
<span class="n">strongly_typed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">,</span>
<span class="n">use_strip_plan</span><span class="o">=</span><span class="n">use_strip_plan</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_add_optimization_profile</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">network</span><span class="p">:</span> <span class="n">Network</span><span class="p">,</span>
<span class="n">builder_config</span><span class="p">:</span> <span class="n">BuilderConfig</span><span class="p">):</span><!-- Builds one TensorRT optimization profile per entry in the inputs' profiles lists, registers each on builder_config.trt_builder_config, then validates named dimensions via _validate_named_dimensions. -->
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">builder_config</span><span class="p">,</span> <span class="n">BuilderConfig</span><span class="p">)</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">network</span><span class="p">,</span> <span class="n">Network</span><span class="p">)</span>
<span class="n">input_tensors</span> <span class="o">=</span> <span class="n">network</span><span class="o">.</span><span class="n">_inputs</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_tensors</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span><!-- No inputs: warn and return without adding any profile (and without validation). -->
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;There are no inputs in the network!&quot;</span><span class="p">)</span>
<span class="k">return</span>
<span class="n">num_profiles</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">input_tensors</span><span class="o">.</span><span class="n">values</span><span class="p">())[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">profiles</span><span class="p">)</span><!-- The profile count is taken from the FIRST input only. NOTE(review): if that input has an empty profiles list, no profiles are added at all; if a later input has fewer profiles, profiles[i] below raises IndexError. Presumably all inputs with profiles share one count; confirm. -->
<span class="n">force_num_profiles</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">builder_config</span><span class="p">,</span> <span class="s2">&quot;force_num_profiles&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span><!-- Optional cap on the number of profiles to add; None means no cap. -->
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_profiles</span><span class="p">):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Adding optimization profile </span><span class="si">{</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s1">/</span><span class="si">{</span><span class="n">num_profiles</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">profile</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_builder</span><span class="o">.</span><span class="n">create_optimization_profile</span><span class="p">()</span>
<span class="k">for</span> <span class="n">input_name</span> <span class="ow">in</span> <span class="n">input_tensors</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_tensors</span><span class="p">[</span><span class="n">input_name</span><span class="p">]</span><span class="o">.</span><span class="n">profiles</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span><!-- Inputs without any profile get no set_shape call. -->
<span class="k">continue</span>
<span class="n">shape_profile</span> <span class="o">=</span> <span class="n">input_tensors</span><span class="p">[</span><span class="n">input_name</span><span class="p">]</span><span class="o">.</span><span class="n">profiles</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">min_shape</span> <span class="o">=</span> <span class="p">[</span><span class="o">*</span><span class="n">shape_profile</span><span class="o">.</span><span class="n">min</span><span class="p">]</span>
<span class="n">opt_shape</span> <span class="o">=</span> <span class="p">[</span><span class="o">*</span><span class="n">shape_profile</span><span class="o">.</span><span class="n">opt</span><span class="p">]</span>
<span class="n">max_shape</span> <span class="o">=</span> <span class="p">[</span><span class="o">*</span><span class="n">shape_profile</span><span class="o">.</span><span class="n">max</span><span class="p">]</span>
<span class="k">if</span> <span class="n">network</span><span class="o">.</span><span class="n">_auto_parallel_config</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span><!-- With auto-parallel, each sharded dim is divided by its shard count: min floored, opt rounded, max ceiled. -->
<span class="n">io_shards</span> <span class="o">=</span> <span class="n">network</span><span class="o">.</span><span class="n">_auto_parallel_config</span><span class="p">[</span><span class="s2">&quot;io_shards&quot;</span><span class="p">]</span>
<span class="k">if</span> <span class="n">input_name</span> <span class="ow">in</span> <span class="n">io_shards</span><span class="p">:</span>
<span class="n">shards</span> <span class="o">=</span> <span class="n">io_shards</span><span class="p">[</span><span class="n">input_name</span><span class="p">]</span>
<span class="k">for</span> <span class="n">dim</span><span class="p">,</span> <span class="n">shard_num</span> <span class="ow">in</span> <span class="n">shards</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">min_shape</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span>
<span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">min_shape</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">/</span> <span class="n">shard_num</span><span class="p">))</span>
<span class="n">opt_shape</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span>
<span class="nb">round</span><span class="p">(</span><span class="n">opt_shape</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">/</span> <span class="n">shard_num</span><span class="p">))</span>
<span class="n">max_shape</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span>
<span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">max_shape</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">/</span> <span class="n">shard_num</span><span class="p">))</span>
<span class="n">profile</span><span class="o">.</span><span class="n">set_shape</span><span class="p">(</span><span class="n">input_name</span><span class="p">,</span> <span class="n">min_shape</span><span class="p">,</span> <span class="n">opt_shape</span><span class="p">,</span> <span class="n">max_shape</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">input_name</span><span class="si">}</span><span class="s1">, min: </span><span class="si">{</span><span class="n">min_shape</span><span class="si">}</span><span class="s1">, opt: </span><span class="si">{</span><span class="n">opt_shape</span><span class="si">}</span><span class="s1">, max: </span><span class="si">{</span><span class="n">max_shape</span><span class="si">}</span><span class="s1">, dimension names: </span><span class="si">{</span><span class="n">shape_profile</span><span class="o">.</span><span class="n">dimension_names</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">)</span>
<span class="n">ret</span> <span class="o">=</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">trt_builder_config</span><span class="o">.</span><span class="n">add_optimization_profile</span><span class="p">(</span>
<span class="n">profile</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Added optimization profile: #</span><span class="si">{</span><span class="n">ret</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">force_num_profiles</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="p">(</span><!-- Stop early once force_num_profiles profiles have been added, but only when that is fewer than num_profiles. -->
<span class="n">i</span> <span class="o">+</span> <span class="mi">1</span>
<span class="p">)</span> <span class="o">==</span> <span class="n">force_num_profiles</span> <span class="ow">and</span> <span class="n">force_num_profiles</span> <span class="o">&lt;</span> <span class="n">num_profiles</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Only adding </span><span class="si">{</span><span class="n">force_num_profiles</span><span class="si">}</span><span class="s2"> profiles instead of </span><span class="si">{</span><span class="n">num_profiles</span><span class="si">}</span><span class="s2">.&quot;</span>
<span class="p">)</span>
<span class="k">break</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_validate_named_dimensions</span><span class="p">(</span><!-- NOTE(review): validation runs inside an assert, so it (and the error logs it emits) is skipped entirely under python -O. -->
<span class="n">network</span><span class="p">,</span> <span class="n">builder_config</span>
<span class="p">),</span> <span class="s2">&quot;Validation of the tensor dimension ranges failed, please check the dimension ranges, find the offensive tensor and dimension name in above the error log&quot;</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_validate_named_dimensions</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">network</span><span class="p">:</span> <span class="n">Network</span><span class="p">,</span>
<span class="n">builder_config</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span><!-- Returns True when, for every profile registered on builder_config.trt_builder_config, each dimension name maps to exactly one (min, opt, max) range across all inputs. Every mismatch is logged before False is returned. -->
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> For each profile, validate that the named dimensions of different input tensors in this profile all have same range.</span>
<span class="sd"> TRT will validate the same condition, validate it earlier to make sure the modeling in TensorRT-LLM are correct and</span>
<span class="sd"> makes the error msg more user friendly.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">valid</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">for</span> <span class="n">profile_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span>
<span class="n">builder_config</span><span class="o">.</span><span class="n">trt_builder_config</span><span class="o">.</span><span class="n">num_optimization_profiles</span><span class="p">):</span>
<span class="n">dimension_to_range</span> <span class="o">=</span> <span class="p">{}</span><!-- Per profile: dimension name to a list of (input_name, (min, opt, max)). -->
<span class="k">for</span> <span class="n">input_name</span><span class="p">,</span> <span class="n">input_tensor</span> <span class="ow">in</span> <span class="n">network</span><span class="o">.</span><span class="n">_inputs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="c1"># it&#39;s legal that a Tensor does not have dim_range?</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_tensor</span><span class="o">.</span><span class="n">profiles</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">profile</span> <span class="o">=</span> <span class="n">input_tensor</span><span class="o">.</span><span class="n">profiles</span><span class="p">[</span><span class="n">profile_idx</span><span class="p">]</span><!-- NOTE(review): presumably input_tensor.profiles is index-aligned with the profiles added to the builder config; an input with fewer entries would raise IndexError here. Confirm. -->
<span class="k">for</span> <span class="n">dim_idx</span><span class="p">,</span> <span class="n">dim_name</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">profile</span><span class="o">.</span><span class="n">dimension_names</span><span class="p">):</span>
<span class="k">if</span> <span class="n">dim_name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">dimension_to_range</span><span class="p">:</span>
<span class="n">dimension_to_range</span><span class="p">[</span><span class="n">dim_name</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
<span class="nb">min</span><span class="p">,</span> <span class="n">opt</span><span class="p">,</span> <span class="nb">max</span> <span class="o">=</span> <span class="n">profile</span><span class="o">.</span><span class="n">min</span><span class="p">[</span><span class="n">dim_idx</span><span class="p">],</span> <span class="n">profile</span><span class="o">.</span><span class="n">opt</span><span class="p">[</span><!-- Note: min and max shadow the builtins for the rest of this method. -->
<span class="n">dim_idx</span><span class="p">],</span> <span class="n">profile</span><span class="o">.</span><span class="n">max</span><span class="p">[</span><span class="n">dim_idx</span><span class="p">]</span>
<span class="n">dimension_to_range</span><span class="p">[</span><span class="n">dim_name</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="p">(</span><span class="n">input_name</span><span class="p">,</span> <span class="p">(</span><span class="nb">min</span><span class="p">,</span> <span class="n">opt</span><span class="p">,</span> <span class="nb">max</span><span class="p">)))</span>
<span class="k">for</span> <span class="n">dim</span><span class="p">,</span> <span class="n">ranges</span> <span class="ow">in</span> <span class="n">dimension_to_range</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">unique_ranges</span> <span class="o">=</span> <span class="nb">set</span><span class="p">([</span><span class="n">r</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">ranges</span><span class="p">])</span><!-- Only the (min, opt, max) tuples are compared; input names are kept for the error log. -->
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Validating dimension:</span><span class="si">{</span><span class="n">dim</span><span class="si">}</span><span class="s2">, ranges for this dim are:</span><span class="si">{</span><span class="n">unique_ranges</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">unique_ranges</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Found illegal dimension setting for profile </span><span class="si">{</span><span class="n">profile_idx</span><span class="si">}</span><span class="s2">, dimension name is: </span><span class="si">{</span><span class="n">dim</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="s2">&quot;Offensive tensors which have this dimension are:</span><span class="se">\n</span><span class="s2">&quot;</span> <span class="o">+</span>
<span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">r</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">dim</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">r</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">&quot;</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">ranges</span><span class="p">]))</span><!-- Each logged line is: range tuple, dimension name, input name. -->
<span class="n">valid</span> <span class="o">=</span> <span class="kc">False</span><!-- Keep scanning so every mismatch is logged, not just the first. -->
<span class="k">return</span> <span class="n">valid</span>
<span class="nd">@_is_building</span>
<span class="k">def</span><span class="w"> </span><span class="nf">refit_engine</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">network</span><span class="p">:</span> <span class="n">Network</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">trt</span><span class="o">.</span><span class="n">IHostMemory</span><span class="p">:</span><!-- Does the following, returning None after logging an error on any failure: deserializes engine_buffer, sets each of network's named parameters on a trt.Refitter, refits, and returns the re-serialized engine. -->
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> @brief: Refit one TensorRT engine using weights from the network,</span>
<span class="sd"> user should guarantee that the engine is built with REFIT flag, and the network has the same structure with the engine.</span>
<span class="sd"> @param engine_buffer: A serialized TensorRT engine.</span>
<span class="sd"> @param network: Network object.</span>
<span class="sd"> @return: A serialized TRT engine if refit successfully, None otherwise</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">network</span><span class="p">,</span> <span class="n">Network</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">&#39;Refit TRT engine&#39;</span><span class="p">)</span>
<span class="n">runtime</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Runtime</span><span class="p">(</span><span class="n">logger</span><span class="o">.</span><span class="n">trt_logger</span><span class="p">)</span>
<span class="n">engine</span> <span class="o">=</span> <span class="n">runtime</span><span class="o">.</span><span class="n">deserialize_cuda_engine</span><span class="p">(</span><span class="n">engine_buffer</span><span class="p">)</span><!-- NOTE(review): the deserialized engine is not checked for None before being handed to the Refitter. -->
<span class="n">tik</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span><!-- Timing starts after deserialization and stops before serialization, so it covers the refit steps only. -->
<span class="c1"># Refit engine</span>
<span class="n">refitter</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Refitter</span><span class="p">(</span><span class="n">engine</span><span class="p">,</span> <span class="n">logger</span><span class="o">.</span><span class="n">trt_logger</span><span class="p">)</span>
<span class="k">if</span> <span class="n">network</span><span class="o">.</span><span class="n">named_parameters</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">network</span><span class="o">.</span><span class="n">named_parameters</span><span class="p">:</span>
<span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">_get_weights</span><span class="p">(</span><!-- NOTE(review): _get_weights() is called here without the network argument that build_engine passes; confirm against the Parameter._get_weights signature. Also, a parameter with no weights aborts the whole refit, whereas build_engine skips such unused parameters. -->
<span class="p">)</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">refitter</span><span class="o">.</span><span class="n">set_named_weights</span><span class="p">(</span>
<span class="n">name</span><span class="p">,</span> <span class="n">param</span><span class="o">.</span><span class="n">_get_weights</span><span class="p">()):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Failed to refit weight: </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="s1">&#39;Please set named parameters before building multiple engines.&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="kc">None</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">refitter</span><span class="o">.</span><span class="n">refit_cuda_engine</span><span class="p">():</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="s1">&#39;Failed to refit engine.&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="kc">None</span>
<span class="n">tok</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="s1">&#39;%H:%M:%S&#39;</span><span class="p">,</span> <span class="n">time</span><span class="o">.</span><span class="n">gmtime</span><span class="p">(</span><span class="n">tok</span> <span class="o">-</span> <span class="n">tik</span><span class="p">))</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Total time of refitting </span><span class="si">{</span><span class="n">engine</span><span class="o">.</span><span class="n">name</span><span class="si">}</span><span class="s1">: </span><span class="si">{</span><span class="n">t</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">serialized_engine</span> <span class="o">=</span> <span class="n">engine</span><span class="o">.</span><span class="n">serialize</span><span class="p">()</span>
<span class="k">return</span> <span class="n">serialized_engine</span>
<span class="nd">@_is_building</span>
<span class="k">def</span><span class="w"> </span><span class="nf">build_engine</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">network</span><span class="p">:</span> <span class="n">Network</span><span class="p">,</span>
<span class="n">builder_config</span><span class="p">:</span> <span class="n">BuilderConfig</span><span class="p">,</span>
<span class="n">managed_weights</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">trt</span><span class="o">.</span><span class="n">IHostMemory</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> @brief: Build one TensorRT engine from the network.</span>
<span class="sd"> @param network: Network object.</span>
<span class="sd"> @param builder_config: BuilderConfig object.</span>
<span class="sd"> @return: A serialized TRT engine.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">network</span><span class="p">,</span> <span class="n">Network</span><span class="p">)</span>
<span class="n">builder_config</span><span class="o">.</span><span class="n">plugin_config</span> <span class="o">=</span> <span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span>
<span class="n">builder_config</span><span class="o">.</span><span class="n">auto_parallel_config</span> <span class="o">=</span> <span class="n">network</span><span class="o">.</span><span class="n">auto_parallel_config</span>
<span class="k">if</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">trt_builder_config</span><span class="o">.</span><span class="n">num_optimization_profiles</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_add_optimization_profile</span><span class="p">(</span><span class="n">network</span><span class="p">,</span> <span class="n">builder_config</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Total optimization profiles added: </span><span class="si">{</span><span class="n">builder_config</span><span class="o">.</span><span class="n">trt_builder_config</span><span class="o">.</span><span class="n">num_optimization_profiles</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="n">engine</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">tik</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="c1"># Rename weights</span>
<span class="k">if</span> <span class="n">network</span><span class="o">.</span><span class="n">named_parameters</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">managed_parameters</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">network</span><span class="o">.</span><span class="n">named_parameters</span><span class="p">:</span>
<span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">is_managed</span><span class="p">(</span><span class="n">network</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">managed_weights</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;managed_weights should be provided when enabled&quot;</span>
<span class="n">managed_parameters</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">param</span><span class="p">)</span>
<span class="n">param</span><span class="o">.</span><span class="n">set_name</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">network</span><span class="p">)</span>
<span class="k">continue</span>
<span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">_get_weights</span><span class="p">(</span><span class="n">network</span><span class="p">)</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">param</span><span class="o">.</span><span class="n">is_buffer</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Parameter </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">param</span><span class="o">.</span><span class="n">raw_value</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">param</span><span class="o">.</span><span class="n">raw_value</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2"> was created&quot;</span>
<span class="s2">&quot; but unused in forward method, so not materialized to TRT network&quot;</span>
<span class="p">)</span>
<span class="k">continue</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">param</span><span class="o">.</span><span class="n">set_name</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">network</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Failed to set weight: </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="c1"># This mark_weights_refittable has no side effect when refit_individual is not enabled.</span>
<span class="n">network</span><span class="o">.</span><span class="n">trt_network</span><span class="o">.</span><span class="n">mark_weights_refittable</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="n">network</span><span class="o">.</span><span class="n">_fill_weights</span><span class="p">()</span>
<span class="n">tok</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="s1">&#39;%H:%M:%S&#39;</span><span class="p">,</span> <span class="n">time</span><span class="o">.</span><span class="n">gmtime</span><span class="p">(</span><span class="n">tok</span> <span class="o">-</span> <span class="n">tik</span><span class="p">))</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;Total time to initialize the weights in network </span><span class="si">{</span><span class="n">network</span><span class="o">.</span><span class="n">trt_network</span><span class="o">.</span><span class="n">name</span><span class="si">}</span><span class="s1">: </span><span class="si">{</span><span class="n">t</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">)</span>
<span class="c1"># Build engine</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Build TensorRT engine </span><span class="si">{</span><span class="n">network</span><span class="o">.</span><span class="n">trt_network</span><span class="o">.</span><span class="n">name</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">tik</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">engine</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_builder</span><span class="o">.</span><span class="n">build_serialized_network</span><span class="p">(</span>
<span class="n">network</span><span class="o">.</span><span class="n">trt_network</span><span class="p">,</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">trt_builder_config</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">engine</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">&#39;Engine building failed, please check the error log.&#39;</span>
<span class="n">tok</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="s1">&#39;%H:%M:%S&#39;</span><span class="p">,</span> <span class="n">time</span><span class="o">.</span><span class="n">gmtime</span><span class="p">(</span><span class="n">tok</span> <span class="o">-</span> <span class="n">tik</span><span class="p">))</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Total time of building </span><span class="si">{</span><span class="n">network</span><span class="o">.</span><span class="n">trt_network</span><span class="o">.</span><span class="n">name</span><span class="si">}</span><span class="s1">: </span><span class="si">{</span><span class="n">t</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">managed_weights</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">network</span><span class="o">.</span><span class="n">named_parameters</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">managed_parameters</span><span class="p">:</span>
<span class="n">name</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">name</span>
<span class="n">value</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">_value</span>
<span class="k">if</span> <span class="n">value</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Failed to get weight: </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">continue</span>
<span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">need_transpose</span><span class="p">:</span>
<span class="c1"># MOE has ndim=3 and uses plugin, no need to transpose</span>
<span class="n">value</span> <span class="o">=</span> <span class="n">value</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="c1"># WAR for bug 4641821</span>
<span class="n">managed_weights</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span>
<span class="k">return</span> <span class="n">engine</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">save_timing_cache</span><span class="p">(</span><span class="n">builder_config</span><span class="p">:</span> <span class="n">BuilderConfig</span><span class="p">,</span> <span class="n">out_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span><!-- save_timing_cache: fetches the timing cache from builder_config.trt_builder_config and writes its serialized bytes to out_path. Returns False (after logging a warning) when no cache exists, True after a successful write. Exceptions raised by open() or write() are not caught and propagate to the caller. -->
<span class="w"> </span><span class="sd">&#39;&#39;&#39;Serialize timing cache of given builder config to file specified by out_path</span>
<span class="sd"> return True if the cache is successfully serialized, False otherwise</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">cache</span> <span class="o">=</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">trt_builder_config</span><span class="o">.</span><span class="n">get_timing_cache</span><span class="p">()</span>
<span class="k">if</span> <span class="n">cache</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span><!-- No cache attached to the builder config: warn and report failure without touching out_path. -->
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s1">&#39;No timing cache found in the given builder config, skip saving.&#39;</span>
<span class="p">)</span>
<span class="k">return</span> <span class="kc">False</span>
<span class="k">with</span> <span class="n">cache</span><span class="o">.</span><span class="n">serialize</span><span class="p">()</span> <span class="k">as</span> <span class="n">buffer</span><span class="p">:</span><!-- The serialized buffer is used as a context manager; presumably this releases the TensorRT host memory on exit. TODO confirm against the TensorRT Python API. -->
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">out_path</span><span class="p">,</span> <span class="s2">&quot;wb&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span><!-- Mode wb truncates any existing file at out_path. -->
<span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">buffer</span><span class="p">)</span>
<span class="n">f</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span><!-- flush() followed by os.fsync() forces the bytes to stable storage before success is logged and True is returned. -->
<span class="n">os</span><span class="o">.</span><span class="n">fsync</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Timing cache serialized to </span><span class="si">{</span><span class="n">out_path</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="kc">True</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">save_config</span><span class="p">(</span><span class="n">builder_config</span><span class="p">:</span> <span class="n">BuilderConfig</span><span class="p">,</span> <span class="n">config_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span><!-- save_config: converts builder_config with to_dict(), writes the result to config_path via to_json_file (a helper defined elsewhere in the package), then logs the path. Returns None; errors from to_json_file are not caught here. -->
<span class="n">config</span> <span class="o">=</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">to_dict</span><span class="p">()</span>
<span class="n">to_json_file</span><span class="p">(</span><span class="n">config</span><span class="p">,</span> <span class="n">config_path</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Config saved to </span><span class="si">{</span><span class="n">config_path</span><span class="si">}</span><span class="s1">.&#39;</span><span class="p">)</span>
<div class="viewcode-block" id="BuildConfig">
<a class="viewcode-back" href="../../llm-api/reference.html#tensorrt_llm.llmapi.BuildConfig">[docs]</a>
<span class="nd">@dataclass</span>
<span class="k">class</span><span class="w"> </span><span class="nc">BuildConfig</span><span class="p">:</span>
<span class="n">max_input_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1024</span><!-- BuildConfig dataclass fields and their declared defaults. NOTE(review): this listing is generated by Sphinx viewcode from tensorrt_llm/builder.py; the annotation fixes flagged below belong in that source file, not in this page. -->
<span class="n">max_seq_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><!-- NOTE(review): annotated int but defaults to None; Optional[int] would match the default, as opt_num_tokens already does. -->
<span class="n">opt_batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span>
<span class="n">max_batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span>
<span class="n">max_beam_width</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">max_num_tokens</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8192</span>
<span class="n">opt_num_tokens</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">max_prompt_embedding_table_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">kv_cache_type</span><span class="p">:</span> <span class="n">KVCacheType</span> <span class="o">=</span> <span class="kc">None</span><!-- None means undecided: update_kv_cache_type later derives it from the plugin paged flag or defaults it to PAGED. NOTE(review): Optional[KVCacheType] would match the None default. -->
<span class="n">gather_context_logits</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">False</span><!-- NOTE(review): annotated int but the default is the bool False; bool looks intended. TODO confirm upstream. -->
<span class="n">gather_generation_logits</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">False</span><!-- NOTE(review): same int vs bool mismatch as gather_context_logits. -->
<span class="n">strongly_typed</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">force_num_profiles</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">profiling_verbosity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;layer_names_only&#39;</span>
<span class="n">enable_debug_output</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">max_draft_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">speculative_decoding_mode</span><span class="p">:</span> <span class="n">SpeculativeDecodingMode</span> <span class="o">=</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">NONE</span>
<span class="n">use_refit</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">input_timing_cache</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><!-- NOTE(review): annotated str but defaults to None; Optional[str] would match. -->
<span class="n">output_timing_cache</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;model.cache&#39;</span>
<span class="n">lora_config</span><span class="p">:</span> <span class="n">LoraConfig</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="n">LoraConfig</span><span class="p">)</span><!-- default_factory builds a fresh LoraConfig per instance, so instances never share a mutable default. -->
<span class="n">auto_parallel_config</span><span class="p">:</span> <span class="n">AutoParallelConfig</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span>
<span class="n">default_factory</span><span class="o">=</span><span class="n">AutoParallelConfig</span><span class="p">)</span>
<span class="n">weight_sparsity</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">weight_streaming</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">plugin_config</span><span class="p">:</span> <span class="n">PluginConfig</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="n">PluginConfig</span><span class="p">)</span>
<span class="n">use_strip_plan</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">max_encoder_input_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1024</span> <span class="c1"># for enc-dec DecoderModel</span>
<span class="n">dry_run</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">visualize_network</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><!-- NOTE(review): annotated str but defaults to None; Optional[str] would match. -->
<span class="n">monitor_memory</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">use_mrope</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
<span class="c1"># Since we have some overlapping between kv_cache_type, paged_kv_cache, and paged_state (later two will be deprecated in the future),</span>
<span class="c1"># we need to handle it given model architecture.</span>
<div class="viewcode-block" id="BuildConfig.update_kv_cache_type">
<a class="viewcode-back" href="../../llm-api/reference.html#tensorrt_llm.llmapi.BuildConfig.update_kv_cache_type">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">update_kv_cache_type</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_architecture</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span><!-- update_kv_cache_type: reconciles self.kv_cache_type with the matching boolean flag on self.plugin_config. The flag is paged_state for MambaForCausalLM and RecurrentGemmaForCausalLM, and paged_kv_cache for every other architecture. Mutates self in place and returns None; raises AssertionError on an inconsistent combination. -->
<span class="n">paged_kv_cache_attr</span> <span class="o">=</span> <span class="s1">&#39;paged_state&#39;</span> <span class="k">if</span> <span class="n">model_architecture</span> <span class="ow">in</span> <span class="p">[</span>
<span class="s1">&#39;MambaForCausalLM&#39;</span><span class="p">,</span> <span class="s1">&#39;RecurrentGemmaForCausalLM&#39;</span>
<span class="p">]</span> <span class="k">else</span> <span class="s1">&#39;paged_kv_cache&#39;</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">paged_kv_cache_val</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">,</span> <span class="n">paged_kv_cache_attr</span><span class="p">)</span><!-- Current flag value; None means the user did not set it. -->
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="n">paged_kv_cache_val</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span><!-- Both set: they must agree, i.e. flag True with PAGED, or flag False with any non-PAGED type. -->
<span class="k">assert</span> <span class="p">(</span><span class="n">paged_kv_cache_val</span> <span class="o">==</span> <span class="kc">True</span><!-- NOTE(review): this assert has no message, and asserts are stripped under python -O, so a mismatch would then pass silently. Upstream could raise ValueError instead. -->
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">==</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">PAGED</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span>
<span class="n">paged_kv_cache_val</span> <span class="o">==</span> <span class="kc">False</span>
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">!=</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">PAGED</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span><!-- Only kv_cache_type set: write the flag as (kv_cache_type == PAGED). -->
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">,</span> <span class="n">paged_kv_cache_attr</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">==</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">PAGED</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="n">paged_kv_cache_val</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span><!-- Only the flag set: derive kv_cache_type from it (truthy gives PAGED, falsy gives CONTINUOUS). -->
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">=</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">PAGED</span> <span class="k">if</span> <span class="n">paged_kv_cache_val</span> <span class="k">else</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">CONTINUOUS</span>
<span class="k">else</span><span class="p">:</span><!-- Neither set: default kv_cache_type to PAGED and set the flag to match. -->
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">=</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">PAGED</span>
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">,</span> <span class="n">paged_kv_cache_attr</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">==</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">PAGED</span><span class="p">)</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="nb">getattr</span><span class="p">(</span><!-- Post-condition: kv_cache_type and the selected plugin flag are both non-None from here on. -->
<span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">,</span> <span class="n">paged_kv_cache_attr</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">def</span><span class="w"> </span><span class="nf">override_attri</span><span class="p">(</span><span class="n">attr_name</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span><!-- Local helper: logs a warning when plugin_config.attr_name already holds a different non-None value, and assigns value to it via setattr. -->
<span class="n">val</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">,</span> <span class="n">attr_name</span><span class="p">)</span>
<span class="k">if</span> <span class="n">val</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">val</span> <span class="o">!=</span> <span class="n">value</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Overriding </span><span class="si">{</span><span class="n">attr_name</span><span class="si">}</span><span class="s1"> to </span><span class="si">{</span><span class="n">value</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">,</span> <span class="n">attr_name</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
<span class="c1"># Init other paged kvcache attri to false. For RecurrentGemma, we only support paged_state and paged_kv_cache have</span>
<span class="c1"># the same values. All other models should only consume either of the value and set other to False.</span>
<span class="n">is_recurrent_gemma</span> <span class="o">=</span> <span class="n">model_architecture</span> <span class="o">==</span> <span class="s1">&#39;RecurrentGemmaForCausalLM&#39;</span>
<span class="k">if</span> <span class="n">paged_kv_cache_attr</span> <span class="o">==</span> <span class="s1">&#39;paged_state&#39;</span><span class="p">:</span><!-- Mamba and RecurrentGemma path: paged_kv_cache mirrors paged_state for RecurrentGemma and is forced to False for MambaForCausalLM. -->
<span class="n">override_attri</span><span class="p">(</span>
<span class="s1">&#39;paged_kv_cache&#39;</span><span class="p">,</span>
<span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">,</span> <span class="n">paged_kv_cache_attr</span><span class="p">)</span>
<span class="k">if</span> <span class="n">is_recurrent_gemma</span> <span class="k">else</span> <span class="kc">False</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span><!-- Every other architecture: paged_state is forced to False. -->
<span class="n">override_attri</span><span class="p">(</span><span class="s1">&#39;paged_state&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span></div>
<div class="viewcode-block" id="BuildConfig.get_build_config_defaults">
<a class="viewcode-back" href="../../llm-api/reference.html#tensorrt_llm.llmapi.BuildConfig.get_build_config_defaults">[docs]</a>
<span class="nd">@classmethod</span>
<span class="nd">@cache</span><!-- Memoizes the result per cls. Assuming cache is functools.cache, every call returns the same dict object, so callers should treat it as read-only. -->
<span class="k">def</span><span class="w"> </span><span class="nf">get_build_config_defaults</span><span class="p">(</span><span class="bp">cls</span><span class="p">):</span><!-- get_build_config_defaults: returns a dict mapping each dataclass field name to its plain default value. Fields declared with default_factory (lora_config, auto_parallel_config, plugin_config) have default set to dataclasses.MISSING and are therefore omitted. -->
<span class="k">return</span> <span class="p">{</span>
<span class="n">field</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">field</span><span class="o">.</span><span class="n">default</span><!-- The loop name field shadows the module-level field helper only inside this comprehension. -->
<span class="k">for</span> <span class="n">field</span> <span class="ow">in</span> <span class="n">dataclasses</span><span class="o">.</span><span class="n">fields</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span>
<span class="k">if</span> <span class="n">field</span><span class="o">.</span><span class="n">default</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">dataclasses</span><span class="o">.</span><span class="n">MISSING</span>
<span class="p">}</span></div>
<div class="viewcode-block" id="BuildConfig.from_dict">
<a class="viewcode-back" href="../../llm-api/reference.html#tensorrt_llm.llmapi.BuildConfig.from_dict">[docs]</a>
<span class="nd">@classmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">from_dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">config</span><span class="p">,</span> <span class="n">plugin_config</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span>
<span class="n">config</span>
<span class="p">)</span> <span class="c1"># it just does not make sense to change the input arg `config`</span>
<span class="n">defaults</span> <span class="o">=</span> <span class="bp">cls</span><span class="o">.</span><span class="n">get_build_config_defaults</span><span class="p">()</span>
<span class="n">max_input_len</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;max_input_len&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;max_input_len&#39;</span><span class="p">))</span>
<span class="n">max_seq_len</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;max_seq_len&#39;</span><span class="p">,</span> <span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;max_seq_len&#39;</span><span class="p">))</span>
<span class="n">max_batch_size</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;max_batch_size&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;max_batch_size&#39;</span><span class="p">))</span>
<span class="n">max_beam_width</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;max_beam_width&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;max_beam_width&#39;</span><span class="p">))</span>
<span class="n">max_num_tokens</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;max_num_tokens&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;max_num_tokens&#39;</span><span class="p">))</span>
<span class="n">opt_num_tokens</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;opt_num_tokens&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;opt_num_tokens&#39;</span><span class="p">))</span>
<span class="n">opt_batch_size</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;opt_batch_size&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;opt_batch_size&#39;</span><span class="p">))</span>
<span class="n">max_prompt_embedding_table_size</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span>
<span class="s1">&#39;max_prompt_embedding_table_size&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;max_prompt_embedding_table_size&#39;</span><span class="p">))</span>
<span class="k">if</span> <span class="s2">&quot;kv_cache_type&quot;</span> <span class="ow">in</span> <span class="n">config</span> <span class="ow">and</span> <span class="n">config</span><span class="p">[</span><span class="s2">&quot;kv_cache_type&quot;</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">from_string</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;kv_cache_type&#39;</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">kv_cache_type</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">gather_context_logits</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span>
<span class="s1">&#39;gather_context_logits&#39;</span><span class="p">,</span> <span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;gather_context_logits&#39;</span><span class="p">))</span>
<span class="n">gather_generation_logits</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span>
<span class="s1">&#39;gather_generation_logits&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;gather_generation_logits&#39;</span><span class="p">))</span>
<span class="n">strongly_typed</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;strongly_typed&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;strongly_typed&#39;</span><span class="p">))</span>
<span class="n">force_num_profiles</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;force_num_profiles&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;force_num_profiles&#39;</span><span class="p">))</span>
<span class="n">weight_sparsity</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;weight_sparsity&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;weight_sparsity&#39;</span><span class="p">))</span>
<span class="n">profiling_verbosity</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;profiling_verbosity&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;profiling_verbosity&#39;</span><span class="p">))</span>
<span class="n">enable_debug_output</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;enable_debug_output&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;enable_debug_output&#39;</span><span class="p">))</span>
<span class="n">max_draft_len</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;max_draft_len&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;max_draft_len&#39;</span><span class="p">))</span>
<span class="n">speculative_decoding_mode</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span>
<span class="s1">&#39;speculative_decoding_mode&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;speculative_decoding_mode&#39;</span><span class="p">))</span>
<span class="n">use_refit</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;use_refit&#39;</span><span class="p">,</span> <span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;use_refit&#39;</span><span class="p">))</span>
<span class="n">input_timing_cache</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;input_timing_cache&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;input_timing_cache&#39;</span><span class="p">))</span>
<span class="n">output_timing_cache</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;output_timing_cache&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;output_timing_cache&#39;</span><span class="p">))</span>
<span class="n">lora_config</span> <span class="o">=</span> <span class="n">LoraConfig</span><span class="o">.</span><span class="n">from_dict</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;lora_config&#39;</span><span class="p">,</span> <span class="p">{}))</span>
<span class="n">auto_parallel_config</span> <span class="o">=</span> <span class="n">AutoParallelConfig</span><span class="o">.</span><span class="n">from_dict</span><span class="p">(</span>
<span class="n">config</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;auto_parallel_config&#39;</span><span class="p">,</span> <span class="p">{}))</span>
<span class="n">max_encoder_input_len</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span>
<span class="s1">&#39;max_encoder_input_len&#39;</span><span class="p">,</span> <span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;max_encoder_input_len&#39;</span><span class="p">))</span>
<span class="n">weight_streaming</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;weight_streaming&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;weight_streaming&#39;</span><span class="p">))</span>
<span class="n">use_strip_plan</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;use_strip_plan&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;use_strip_plan&#39;</span><span class="p">))</span>
<span class="k">if</span> <span class="n">plugin_config</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">plugin_config</span> <span class="o">=</span> <span class="n">PluginConfig</span><span class="p">()</span>
<span class="k">if</span> <span class="s2">&quot;plugin_config&quot;</span> <span class="ow">in</span> <span class="n">config</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="n">plugin_config</span><span class="o">.</span><span class="n">update_from_dict</span><span class="p">(</span><span class="n">config</span><span class="p">[</span><span class="s2">&quot;plugin_config&quot;</span><span class="p">])</span>
<span class="n">dry_run</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;dry_run&#39;</span><span class="p">,</span> <span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;dry_run&#39;</span><span class="p">))</span>
<span class="n">visualize_network</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;visualize_network&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;visualize_network&#39;</span><span class="p">))</span>
<span class="n">monitor_memory</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;monitor_memory&#39;</span><span class="p">,</span>
<span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;monitor_memory&#39;</span><span class="p">))</span>
<span class="n">use_mrope</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;use_mrope&#39;</span><span class="p">,</span> <span class="n">defaults</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;use_mrope&#39;</span><span class="p">))</span>
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span>
<span class="n">max_input_len</span><span class="o">=</span><span class="n">max_input_len</span><span class="p">,</span>
<span class="n">max_seq_len</span><span class="o">=</span><span class="n">max_seq_len</span><span class="p">,</span>
<span class="n">max_batch_size</span><span class="o">=</span><span class="n">max_batch_size</span><span class="p">,</span>
<span class="n">max_beam_width</span><span class="o">=</span><span class="n">max_beam_width</span><span class="p">,</span>
<span class="n">max_num_tokens</span><span class="o">=</span><span class="n">max_num_tokens</span><span class="p">,</span>
<span class="n">opt_num_tokens</span><span class="o">=</span><span class="n">opt_num_tokens</span><span class="p">,</span>
<span class="n">opt_batch_size</span><span class="o">=</span><span class="n">opt_batch_size</span><span class="p">,</span>
<span class="n">max_prompt_embedding_table_size</span><span class="o">=</span><span class="n">max_prompt_embedding_table_size</span><span class="p">,</span>
<span class="n">kv_cache_type</span><span class="o">=</span><span class="n">kv_cache_type</span><span class="p">,</span>
<span class="n">gather_context_logits</span><span class="o">=</span><span class="n">gather_context_logits</span><span class="p">,</span>
<span class="n">gather_generation_logits</span><span class="o">=</span><span class="n">gather_generation_logits</span><span class="p">,</span>
<span class="n">strongly_typed</span><span class="o">=</span><span class="n">strongly_typed</span><span class="p">,</span>
<span class="n">force_num_profiles</span><span class="o">=</span><span class="n">force_num_profiles</span><span class="p">,</span>
<span class="n">profiling_verbosity</span><span class="o">=</span><span class="n">profiling_verbosity</span><span class="p">,</span>
<span class="n">enable_debug_output</span><span class="o">=</span><span class="n">enable_debug_output</span><span class="p">,</span>
<span class="n">max_draft_len</span><span class="o">=</span><span class="n">max_draft_len</span><span class="p">,</span>
<span class="n">speculative_decoding_mode</span><span class="o">=</span><span class="n">speculative_decoding_mode</span><span class="p">,</span>
<span class="n">use_refit</span><span class="o">=</span><span class="n">use_refit</span><span class="p">,</span>
<span class="n">input_timing_cache</span><span class="o">=</span><span class="n">input_timing_cache</span><span class="p">,</span>
<span class="n">output_timing_cache</span><span class="o">=</span><span class="n">output_timing_cache</span><span class="p">,</span>
<span class="n">lora_config</span><span class="o">=</span><span class="n">lora_config</span><span class="p">,</span>
<span class="n">auto_parallel_config</span><span class="o">=</span><span class="n">auto_parallel_config</span><span class="p">,</span>
<span class="n">use_strip_plan</span><span class="o">=</span><span class="n">use_strip_plan</span><span class="p">,</span>
<span class="n">max_encoder_input_len</span><span class="o">=</span><span class="n">max_encoder_input_len</span><span class="p">,</span>
<span class="n">weight_sparsity</span><span class="o">=</span><span class="n">weight_sparsity</span><span class="p">,</span>
<span class="n">weight_streaming</span><span class="o">=</span><span class="n">weight_streaming</span><span class="p">,</span>
<span class="n">plugin_config</span><span class="o">=</span><span class="n">plugin_config</span><span class="p">,</span>
<span class="n">dry_run</span><span class="o">=</span><span class="n">dry_run</span><span class="p">,</span>
<span class="n">visualize_network</span><span class="o">=</span><span class="n">visualize_network</span><span class="p">,</span>
<span class="n">monitor_memory</span><span class="o">=</span><span class="n">monitor_memory</span><span class="p">,</span>
<span class="n">use_mrope</span><span class="o">=</span><span class="n">use_mrope</span><span class="p">)</span></div>
<div class="viewcode-block" id="BuildConfig.from_json_file">
<a class="viewcode-back" href="../../llm-api/reference.html#tensorrt_llm.llmapi.BuildConfig.from_json_file">[docs]</a>
<span class="nd">@classmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">from_json_file</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">config_file</span><span class="p">,</span> <span class="n">plugin_config</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">config_file</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
<span class="k">return</span> <span class="n">BuildConfig</span><span class="o">.</span><span class="n">from_dict</span><span class="p">(</span><span class="n">config</span><span class="p">,</span> <span class="n">plugin_config</span><span class="o">=</span><span class="n">plugin_config</span><span class="p">)</span></div>
<div class="viewcode-block" id="BuildConfig.to_dict">
<a class="viewcode-back" href="../../llm-api/reference.html#tensorrt_llm.llmapi.BuildConfig.to_dict">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">to_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)</span>
<span class="c1"># the enum KVCacheType cannot be converted automatically</span>
<span class="k">if</span> <span class="n">output</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;kv_cache_type&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">output</span><span class="p">[</span><span class="s1">&#39;kv_cache_type&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">output</span><span class="p">[</span><span class="s1">&#39;kv_cache_type&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">name</span><span class="p">)</span>
<span class="n">output</span><span class="p">[</span><span class="s1">&#39;plugin_config&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">output</span><span class="p">[</span><span class="s1">&#39;plugin_config&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">to_dict</span><span class="p">()</span>
<span class="n">output</span><span class="p">[</span><span class="s1">&#39;lora_config&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">output</span><span class="p">[</span><span class="s1">&#39;lora_config&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">to_dict</span><span class="p">()</span>
<span class="n">output</span><span class="p">[</span><span class="s1">&#39;auto_parallel_config&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">output</span><span class="p">[</span><span class="s1">&#39;auto_parallel_config&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">to_dict</span><span class="p">(</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">output</span></div>
<div class="viewcode-block" id="BuildConfig.update_from_dict">
<a class="viewcode-back" href="../../llm-api/reference.html#tensorrt_llm.llmapi.BuildConfig.update_from_dict">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">update_from_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">config</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">config</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">AttributeError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="si">}</span><span class="s2"> object has no attribute </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="BuildConfig.update">
<a class="viewcode-back" href="../../llm-api/reference.html#tensorrt_llm.llmapi.BuildConfig.update">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">update_from_dict</span><span class="p">(</span><span class="n">kwargs</span><span class="p">)</span></div>
</div>
<span class="k">class</span><span class="w"> </span><span class="nc">EngineConfig</span><span class="p">:</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pretrained_config</span><span class="p">:</span> <span class="s1">&#39;PretrainedConfig&#39;</span><span class="p">,</span>
<span class="n">build_config</span><span class="p">:</span> <span class="s1">&#39;BuildConfig&#39;</span><span class="p">,</span> <span class="n">version</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">pretrained_config</span> <span class="o">=</span> <span class="n">pretrained_config</span>
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span> <span class="o">=</span> <span class="n">build_config</span>
<span class="bp">self</span><span class="o">.</span><span class="n">version</span> <span class="o">=</span> <span class="n">version</span>
<span class="nd">@classmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">from_json_file</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">config_file</span><span class="p">):</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">config_file</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">from_json_str</span><span class="p">(</span><span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">())</span>
<span class="nd">@classmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">from_json_str</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">config_str</span><span class="p">):</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">loads</span><span class="p">(</span><span class="n">config_str</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="n">PretrainedConfig</span><span class="o">.</span><span class="n">from_dict</span><span class="p">(</span><span class="n">config</span><span class="p">[</span><span class="s1">&#39;pretrained_config&#39;</span><span class="p">]),</span>
<span class="n">BuildConfig</span><span class="o">.</span><span class="n">from_dict</span><span class="p">(</span><span class="n">config</span><span class="p">[</span><span class="s1">&#39;build_config&#39;</span><span class="p">]),</span>
<span class="n">config</span><span class="p">[</span><span class="s1">&#39;version&#39;</span><span class="p">])</span>
<span class="k">def</span><span class="w"> </span><span class="nf">to_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="n">build_config</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">to_dict</span><span class="p">()</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;dry_run&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span> <span class="c1"># Not an Engine Characteristic</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;visualize_network&#39;</span><span class="p">,</span>
<span class="kc">None</span><span class="p">)</span> <span class="c1"># Not an Engine Characteristic</span>
<span class="k">return</span> <span class="p">{</span>
<span class="s1">&#39;version&#39;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">version</span><span class="p">,</span>
<span class="s1">&#39;pretrained_config&#39;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">pretrained_config</span><span class="o">.</span><span class="n">to_dict</span><span class="p">(),</span>
<span class="s1">&#39;build_config&#39;</span><span class="p">:</span> <span class="n">build_config</span><span class="p">,</span>
<span class="p">}</span>
<span class="k">class</span><span class="w"> </span><span class="nc">Engine</span><span class="p">:</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">config</span><span class="p">:</span> <span class="n">EngineConfig</span><span class="p">,</span>
<span class="n">engine</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">IHostMemory</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span>
<span class="n">managed_weights</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">=</span> <span class="p">{},</span>
<span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">config</span> <span class="o">=</span> <span class="n">config</span>
<span class="bp">self</span><span class="o">.</span><span class="n">engine</span> <span class="o">=</span> <span class="n">engine</span>
<span class="bp">self</span><span class="o">.</span><span class="n">managed_weights</span> <span class="o">=</span> <span class="n">managed_weights</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">managed_weights</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">managed_weights</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">managed_weights</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">value</span><span class="o">.</span><span class="n">flags</span><span class="p">[</span><span class="s1">&#39;C_CONTIGUOUS&#39;</span><span class="p">]:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">managed_weights</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">save</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">engine_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">lora_config</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">lora_config</span>
<span class="n">lora_dirs</span> <span class="o">=</span> <span class="n">lora_config</span><span class="o">.</span><span class="n">lora_dir</span>
<span class="n">root_lora_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">,</span> <span class="s1">&#39;lora&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">lora_dirs</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">root_lora_dir</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">lora_dir</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">lora_dirs</span><span class="p">):</span>
<span class="k">if</span> <span class="n">lora_config</span><span class="o">.</span><span class="n">lora_ckpt_source</span> <span class="o">==</span> <span class="s1">&#39;hf&#39;</span><span class="p">:</span>
<span class="n">target_lora_dir</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">root_lora_dir</span><span class="si">}</span><span class="s2">/</span><span class="si">{</span><span class="n">index</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">target_lora_dir</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">shutil</span><span class="o">.</span><span class="n">copy2</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">lora_dir</span><span class="p">,</span> <span class="s1">&#39;adapter_config.json&#39;</span><span class="p">),</span>
<span class="n">target_lora_dir</span><span class="p">)</span>
<span class="n">weight_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">lora_dir</span><span class="p">,</span> <span class="s1">&#39;adapter_model.bin&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">weight_file</span><span class="p">):</span>
<span class="n">shutil</span><span class="o">.</span><span class="n">copy2</span><span class="p">(</span><span class="n">weight_file</span><span class="p">,</span> <span class="n">target_lora_dir</span><span class="p">)</span>
<span class="n">weight_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">lora_dir</span><span class="p">,</span>
<span class="s1">&#39;adapter_model.safetensors&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">weight_file</span><span class="p">):</span>
<span class="n">shutil</span><span class="o">.</span><span class="n">copy2</span><span class="p">(</span><span class="n">weight_file</span><span class="p">,</span> <span class="n">target_lora_dir</span><span class="p">)</span>
<span class="n">lora_config</span><span class="o">.</span><span class="n">lora_dir</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;lora/</span><span class="si">{</span><span class="n">index</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="k">elif</span> <span class="n">lora_config</span><span class="o">.</span><span class="n">lora_ckpt_source</span> <span class="o">==</span> <span class="s1">&#39;nemo&#39;</span><span class="p">:</span>
<span class="n">target_lora_file</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">root_lora_dir</span><span class="si">}</span><span class="s2">/</span><span class="si">{</span><span class="n">index</span><span class="si">}</span><span class="s2">.nemo&quot;</span>
<span class="n">shutil</span><span class="o">.</span><span class="n">copyfile</span><span class="p">(</span><span class="n">lora_dir</span><span class="p">,</span> <span class="n">target_lora_file</span><span class="p">)</span>
<span class="n">lora_config</span><span class="o">.</span><span class="n">lora_dir</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;lora/</span><span class="si">{</span><span class="n">index</span><span class="si">}</span><span class="s2">.nemo&quot;</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">root_lora_dir</span><span class="p">)</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isdir</span><span class="p">(</span><span class="n">root_lora_dir</span><span class="p">):</span>
<span class="n">shutil</span><span class="o">.</span><span class="n">rmtree</span><span class="p">(</span><span class="n">root_lora_dir</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">pretrained_config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">config_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">to_dict</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">pretrained_config</span><span class="o">.</span><span class="n">quant_algo</span> <span class="o">==</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">MIXED_PRECISION</span><span class="p">:</span>
<span class="n">quant_dict</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;version&#39;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">version</span><span class="p">,</span>
<span class="p">}</span>
<span class="n">quant_dict</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
<span class="n">config_dict</span><span class="p">[</span><span class="s1">&#39;pretrained_config&#39;</span><span class="p">][</span><span class="s1">&#39;quantization&#39;</span><span class="p">])</span>
<span class="n">config_dict</span><span class="p">[</span><span class="s1">&#39;pretrained_config&#39;</span><span class="p">][</span><span class="s1">&#39;quantization&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span>
<span class="s1">&#39;quantized_layers&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">,</span> <span class="s1">&#39;quant_cfg.json&#39;</span><span class="p">),</span>
<span class="s2">&quot;w&quot;</span><span class="p">,</span>
<span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">json</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="n">quant_dict</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">indent</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="bp">cls</span><span class="o">=</span><span class="n">ConfigEncoder</span><span class="p">)</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">,</span> <span class="s1">&#39;config.json&#39;</span><span class="p">),</span>
<span class="s2">&quot;w&quot;</span><span class="p">,</span>
<span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">json</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="n">config_dict</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">indent</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="bp">cls</span><span class="o">=</span><span class="n">ConfigEncoder</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">serialize_engine</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="p">,</span>
<span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span>
<span class="n">engine_dir</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;rank</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">pretrained_config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span><span class="si">}</span><span class="s1">.engine&#39;</span><span class="p">))</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">managed_weights</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">managed_weights</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">fn</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span>
<span class="n">engine_dir</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;rank</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">pretrained_config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span><span class="si">}</span><span class="s1">_managed_weights.safetensors&#39;</span>
<span class="p">)</span>
<span class="n">serialize_managed_weights</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">managed_weights</span><span class="p">,</span> <span class="n">fn</span><span class="p">)</span>
<span class="nd">@classmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">from_dir</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">engine_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;rank</span><span class="si">{</span><span class="n">rank</span><span class="si">}</span><span class="s1">.engine&#39;</span><span class="p">),</span> <span class="s1">&#39;rb&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">engine_buffer</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">()</span>
<span class="n">mw_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;rank</span><span class="si">{</span><span class="n">rank</span><span class="si">}</span><span class="s1">_managed_weights.safetensors&#39;</span><span class="p">)</span>
<span class="n">managed_weights</span> <span class="o">=</span> <span class="n">deserialize_managed_weights</span><span class="p">(</span>
<span class="n">mw_path</span><span class="p">)</span> <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">mw_path</span><span class="p">)</span> <span class="k">else</span> <span class="kc">None</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">EngineConfig</span><span class="o">.</span><span class="n">from_json_file</span><span class="p">(</span>
<span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">,</span> <span class="s1">&#39;config.json&#39;</span><span class="p">))</span>
<span class="n">config</span><span class="o">.</span><span class="n">pretrained_config</span><span class="o">.</span><span class="n">set_rank</span><span class="p">(</span><span class="n">rank</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="n">config</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">,</span> <span class="n">managed_weights</span><span class="p">)</span>
<span class="nd">@classmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">from_buffer</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">IHostMemory</span><span class="p">,</span> <span class="nb">bytes</span><span class="p">],</span>
<span class="n">json_config_str</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">EngineConfig</span><span class="o">.</span><span class="n">from_json_str</span><span class="p">(</span><span class="n">json_config_str</span><span class="p">)</span>
<span class="n">config</span><span class="o">.</span><span class="n">pretrained_config</span><span class="o">.</span><span class="n">set_rank</span><span class="p">(</span><span class="n">rank</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="n">config</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">get_engine_version</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">str</span><span class="p">]:</span>
<span class="n">engine_dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">)</span>
<span class="n">config_path</span> <span class="o">=</span> <span class="n">engine_dir</span> <span class="o">/</span> <span class="s2">&quot;config.json&quot;</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">config_path</span><span class="p">,</span> <span class="s1">&#39;r&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
<span class="k">if</span> <span class="s1">&#39;version&#39;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">config</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">None</span>
<span class="k">return</span> <span class="n">config</span><span class="p">[</span><span class="s1">&#39;version&#39;</span><span class="p">]</span>
<span class="k">def</span><span class="w"> </span><span class="nf">optimize_model_with_config</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">PretrainedModel</span><span class="p">,</span>
<span class="n">build_config</span><span class="p">:</span> <span class="n">BuildConfig</span><span class="p">):</span>
<span class="n">use_auto_parallel</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">auto_parallel_config</span><span class="o">.</span><span class="n">enabled</span>
<span class="n">gemm_swiglu_plugin</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_swiglu_plugin</span>
<span class="n">low_latency_gemm_swiglu_plugin</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">low_latency_gemm_swiglu_plugin</span>
<span class="k">if</span> <span class="n">gemm_swiglu_plugin</span> <span class="ow">or</span> <span class="n">low_latency_gemm_swiglu_plugin</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fused_mlp</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;GemmSwiGLU plugin requires --use_fused_mlp flag&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">gemm_swiglu_plugin</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span>
<span class="s2">&quot;fp8&quot;</span>
<span class="p">]</span> <span class="ow">and</span> <span class="n">low_latency_gemm_swiglu_plugin</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;fp8&quot;</span><span class="p">]:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;GemmSwiGLU plugin currently has limited support: fp8 only, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;got: </span><span class="si">{</span><span class="n">gemm_swiglu_plugin</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="sa">f</span><span class="s2">&quot;got: </span><span class="si">{</span><span class="n">low_latency_gemm_swiglu_plugin</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lora_plugin</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">model</span><span class="o">.</span><span class="n">use_lora</span><span class="p">(</span><span class="n">build_config</span><span class="o">.</span><span class="n">lora_config</span><span class="p">)</span>
<span class="n">is_enc_dec</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;EncoderModel&quot;</span><span class="p">,</span> <span class="s2">&quot;DecoderModel&quot;</span><span class="p">]</span>
<span class="c1"># FusedMLP does not support RecurrentGemma FP8 currently.</span>
<span class="n">is_recurrent_gemma</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="ow">in</span> <span class="p">[</span>
<span class="s2">&quot;RecurrentGemmaForCausalLM&quot;</span>
<span class="p">]</span>
<span class="n">is_fp8</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quantization</span><span class="o">.</span><span class="n">quant_algo</span> <span class="o">==</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">optimize_model</span><span class="p">(</span>
<span class="n">model</span><span class="p">,</span>
<span class="n">share_embedding_table</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">use_ootb_moe</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">moe_plugin</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">use_fused_mlp</span><span class="o">=</span><span class="p">(</span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fused_mlp</span>
<span class="ow">and</span> <span class="ow">not</span> <span class="n">is_enc_dec</span>
<span class="ow">and</span> <span class="ow">not</span> <span class="p">(</span><span class="n">is_recurrent_gemma</span> <span class="ow">and</span> <span class="n">is_fp8</span><span class="p">)</span>
<span class="ow">and</span> <span class="ow">not</span> <span class="n">use_auto_parallel</span><span class="p">),</span>
<span class="n">gemm_swiglu_plugin_dtype</span><span class="o">=</span><span class="n">gemm_swiglu_plugin</span><span class="p">,</span>
<span class="n">low_latency_gemm_swiglu_plugin_dtype</span><span class="o">=</span><span class="n">low_latency_gemm_swiglu_plugin</span><span class="p">,</span>
<span class="n">use_fused_rg_lru</span><span class="o">=</span><span class="n">is_recurrent_gemma</span><span class="p">,</span>
<span class="n">use_unfused_qkv_gemm</span><span class="o">=</span><span class="n">use_auto_parallel</span><span class="p">,</span>
<span class="n">use_prompt_tuning</span><span class="o">=</span><span class="p">(</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_prompt_embedding_table_size</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">use_lora</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lora_plugin</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">max_lora_rank</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">lora_config</span><span class="o">.</span><span class="n">max_lora_rank</span><span class="p">,</span>
<span class="n">use_fp8_context_fmha</span><span class="o">=</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quantization</span><span class="o">.</span><span class="n">quant_algo</span> <span class="ow">in</span> <span class="p">[</span>
<span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8</span><span class="p">,</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">W4A8_AWQ</span><span class="p">,</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">NVFP4</span>
<span class="p">]</span> <span class="ow">and</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span><span class="p">),</span>
<span class="n">fuse_fp4_quant</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">fuse_fp4_quant</span><span class="p">,</span>
<span class="n">use_optimize_cross_qkv</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">use_dora</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">dora_plugin</span><span class="p">)</span>
<span class="k">if</span> <span class="n">is_enc_dec</span><span class="p">:</span>
<span class="n">model</span><span class="o">.</span><span class="n">precompute_relative_attention_bias</span><span class="p">(</span><span class="n">build_config</span><span class="p">)</span>
<span class="k">return</span> <span class="n">model</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_init_max_seq_len</span><span class="p">(</span><span class="n">model_config</span><span class="p">,</span> <span class="n">build_config</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> If max_seq_len is not specified, set it to max_position_embeddings * rotary_factor</span>
<span class="sd"> Additional checks to ensure max_seq_len, max_input_len, and max_num_tokens have valid values.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># Extract rotary scaling which will be used for checks and default value of max_seq_len</span>
<span class="n">rotary_scaling</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">model_config</span><span class="p">,</span> <span class="s2">&quot;rotary_scaling&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">if</span> <span class="n">rotary_scaling</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">rotary_type</span> <span class="o">=</span> <span class="n">rotary_scaling</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;type&#39;</span><span class="p">,</span>
<span class="n">rotary_scaling</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;rope_type&#39;</span><span class="p">))</span>
<span class="n">rotary_factor</span> <span class="o">=</span> <span class="n">rotary_scaling</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
<span class="s1">&#39;factor&#39;</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">)</span> <span class="k">if</span> <span class="n">rotary_type</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">(</span><span class="s2">&quot;su&quot;</span><span class="p">,</span> <span class="s2">&quot;longrope&quot;</span><span class="p">,</span>
<span class="s2">&quot;llama3&quot;</span><span class="p">)</span> <span class="k">else</span> <span class="mi">1</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">rotary_factor</span> <span class="o">=</span> <span class="mi">1</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">==</span> <span class="s2">&quot;EncoderModel&quot;</span><span class="p">:</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_input_len</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;max_seq_len is not specified for EncoderModel, using --max_input_len.&#39;</span>
<span class="p">)</span>
<span class="k">assert</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_input_len</span> <span class="o">==</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;EncoderModel should have same --max_input_len (</span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_input_len</span><span class="si">}</span><span class="s2">) and --max_seq_len (</span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="si">}</span><span class="s2">).&quot;</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># Step 1: Find the upper bound of max_seq_len</span>
<span class="n">deduced_max_seq_len</span> <span class="o">=</span> <span class="mi">2048</span>
<span class="k">if</span> <span class="n">model_config</span><span class="o">.</span><span class="n">max_position_embeddings</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">deduced_max_seq_len</span> <span class="o">=</span> <span class="n">model_config</span><span class="o">.</span><span class="n">max_position_embeddings</span>
<span class="c1"># Step 2: Scale max_seq_len with rotary scaling</span>
<span class="k">if</span> <span class="n">rotary_factor</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">deduced_max_seq_len</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">deduced_max_seq_len</span> <span class="o">*</span> <span class="n">rotary_factor</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;max_seq_len is scaled to </span><span class="si">{</span><span class="n">deduced_max_seq_len</span><span class="si">}</span><span class="s1"> by rotary scaling </span><span class="si">{</span><span class="n">rotary_factor</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">)</span>
<span class="c1"># Step 3: Assign the new max_seq_len</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">deduced_max_seq_len</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;max_seq_len is not specified, using deduced value </span><span class="si">{</span><span class="n">deduced_max_seq_len</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">streamingllm</span> <span class="ow">and</span> <span class="n">model_config</span><span class="o">.</span><span class="n">max_position_embeddings</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> \
<span class="ow">and</span> <span class="n">model_config</span><span class="o">.</span><span class="n">position_embedding_type</span> <span class="o">!=</span> <span class="n">PositionEmbeddingType</span><span class="o">.</span><span class="n">relative</span><span class="p">:</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span> <span class="o">&gt;</span> <span class="n">model_config</span><span class="o">.</span><span class="n">max_position_embeddings</span> <span class="o">*</span> <span class="n">rotary_factor</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;max_seq_len </span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="si">}</span><span class="s1"> is larger than max_position_embeddings </span><span class="si">{</span><span class="n">model_config</span><span class="o">.</span><span class="n">max_position_embeddings</span><span class="si">}</span><span class="s1"> * rotary scaling </span><span class="si">{</span><span class="n">rotary_factor</span><span class="si">}</span><span class="s1">, &#39;</span>
<span class="s1">&#39;the model accuracy might be affected&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_input_len</span> <span class="o">&gt;</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;max_input_len is </span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_input_len</span><span class="si">}</span><span class="s1"> is larger than max_seq_len </span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="si">}</span><span class="s1">, clipping it to max_seq_len&#39;</span>
<span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_input_len</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span>
<span class="c1"># Check and may modify max_num_tokens and opt_num_tokens (need to happen after max_seq_len is deduced)</span>
<span class="n">max_num_tokens</span><span class="p">,</span> <span class="n">opt_num_tokens</span> <span class="o">=</span> <span class="n">check_max_num_tokens</span><span class="p">(</span>
<span class="n">max_num_tokens</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_num_tokens</span><span class="p">,</span>
<span class="n">opt_num_tokens</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">opt_num_tokens</span><span class="p">,</span>
<span class="n">max_batch_size</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">,</span>
<span class="n">max_input_len</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_input_len</span><span class="p">,</span>
<span class="n">max_seq_len</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="p">,</span>
<span class="n">max_beam_width</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_beam_width</span><span class="p">,</span>
<span class="n">remove_input_padding</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span><span class="p">,</span>
<span class="n">enable_context_fmha</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">context_fmha</span><span class="p">,</span>
<span class="n">tokens_per_block</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">tokens_per_block</span><span class="p">,</span>
<span class="n">multiple_profiles</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">multiple_profiles</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_num_tokens</span><span class="p">,</span> <span class="n">build_config</span><span class="o">.</span><span class="n">opt_num_tokens</span> <span class="o">=</span> <span class="n">max_num_tokens</span><span class="p">,</span> <span class="n">opt_num_tokens</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">remove_input_padding</span> <span class="ow">and</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">context_fmha</span><span class="p">:</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_input_len</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s1">&#39;padding removal and fMHA are both enabled, max_input_len is not required and will be ignored&#39;</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_input_len</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">&#39;padding removal and fMHA aren</span><span class="se">\&#39;</span><span class="s1">t both enabled, max_input_len is required&#39;</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_input_len</span> <span class="o">&lt;=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="p">,</span> <span class="s1">&#39;max_input_len should not be larger than max_seq_len&#39;</span>
<span class="k">def</span><span class="w"> </span><span class="nf">serialize_managed_weights</span><span class="p">(</span><span class="n">managed_weights</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">],</span>
<span class="n">path</span><span class="p">:</span> <span class="nb">str</span> <span class="o">|</span> <span class="n">Path</span><span class="p">,</span>
<span class="n">metadata</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">header</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">if</span> <span class="n">metadata</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">header</span><span class="p">[</span><span class="s2">&quot;__metadata__&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">metadata</span>
<span class="n">begin</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">managed_weights</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">size</span> <span class="o">=</span> <span class="n">value</span><span class="o">.</span><span class="n">size</span> <span class="o">*</span> <span class="n">value</span><span class="o">.</span><span class="n">itemsize</span>
<span class="k">if</span> <span class="n">value</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="s2">&quot;F32&quot;</span>
<span class="k">elif</span> <span class="n">value</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="s2">&quot;F16&quot;</span>
<span class="k">elif</span> <span class="n">value</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">np_bfloat16</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="s2">&quot;BF16&quot;</span>
<span class="k">elif</span> <span class="n">value</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">np_float8</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="s2">&quot;F8_E4M3&quot;</span>
<span class="k">elif</span> <span class="n">value</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">int64</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="s2">&quot;I64&quot;</span>
<span class="k">elif</span> <span class="n">value</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="s2">&quot;I32&quot;</span>
<span class="k">elif</span> <span class="n">value</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="n">np</span><span class="o">.</span><span class="n">int8</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="s2">&quot;I8&quot;</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Unsupported dtype: </span><span class="si">{</span><span class="n">value</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">header</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span>
<span class="s2">&quot;dtype&quot;</span><span class="p">:</span> <span class="n">dtype</span><span class="p">,</span>
<span class="s2">&quot;shape&quot;</span><span class="p">:</span> <span class="n">value</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span>
<span class="s2">&quot;data_offsets&quot;</span><span class="p">:</span> <span class="p">[</span><span class="n">begin</span><span class="p">,</span> <span class="n">begin</span> <span class="o">+</span> <span class="n">size</span><span class="p">],</span>
<span class="p">}</span>
<span class="n">begin</span> <span class="o">+=</span> <span class="n">size</span>
<span class="n">header_json</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="n">header</span><span class="p">)</span>
<span class="n">header_json_len</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">header_json</span><span class="p">)</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">&quot;wb&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Serializing </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">managed_weights</span><span class="p">)</span><span class="si">}</span><span class="s2"> managed weights to </span><span class="si">{</span><span class="n">path</span><span class="si">}</span><span class="s2">...&quot;</span><span class="p">)</span>
<span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">header_json_len</span><span class="o">.</span><span class="n">to_bytes</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="n">byteorder</span><span class="o">=</span><span class="s2">&quot;little&quot;</span><span class="p">))</span>
<span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">header_json</span><span class="o">.</span><span class="n">encode</span><span class="p">())</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">managed_weights</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Serializing managed weight: </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">buf</span> <span class="o">=</span> <span class="n">value</span><span class="o">.</span><span class="n">data</span>
<span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">buf</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">deserialize_managed_weights</span><span class="p">(</span><span class="n">path</span><span class="p">:</span> <span class="nb">str</span> <span class="o">|</span> <span class="n">Path</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]:</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">&quot;rb&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">header_json_len</span> <span class="o">=</span> <span class="nb">int</span><span class="o">.</span><span class="n">from_bytes</span><span class="p">(</span><span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">8</span><span class="p">),</span> <span class="n">byteorder</span><span class="o">=</span><span class="s2">&quot;little&quot;</span><span class="p">)</span>
<span class="n">header_json</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">header_json_len</span><span class="p">)</span><span class="o">.</span><span class="n">decode</span><span class="p">()</span>
<span class="n">header</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">loads</span><span class="p">(</span><span class="n">header_json</span><span class="p">)</span>
<span class="n">managed_weights</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">info</span> <span class="ow">in</span> <span class="n">header</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">info</span><span class="p">[</span><span class="s2">&quot;dtype&quot;</span><span class="p">]</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">info</span><span class="p">[</span><span class="s2">&quot;shape&quot;</span><span class="p">]</span>
<span class="n">data_offsets</span> <span class="o">=</span> <span class="n">info</span><span class="p">[</span><span class="s2">&quot;data_offsets&quot;</span><span class="p">]</span>
<span class="k">if</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">&quot;F32&quot;</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span>
<span class="k">elif</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">&quot;F16&quot;</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span>
<span class="k">elif</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">&quot;BF16&quot;</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">np_bfloat16</span>
<span class="k">elif</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">&quot;F8_E4M3&quot;</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">np_float8</span>
<span class="k">elif</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">&quot;I64&quot;</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">int64</span>
<span class="k">elif</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">&quot;I32&quot;</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Unsupported dtype: </span><span class="si">{</span><span class="n">dtype</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">f</span><span class="o">.</span><span class="n">seek</span><span class="p">(</span><span class="n">data_offsets</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">header_json_len</span> <span class="o">+</span> <span class="mi">8</span><span class="p">)</span>
<span class="n">buf</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">data_offsets</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">data_offsets</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">value</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">frombuffer</span><span class="p">(</span><span class="n">buf</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
<span class="n">managed_weights</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span>
<span class="k">return</span> <span class="n">managed_weights</span>
<span class="k">def</span><span class="w"> </span><span class="nf">build</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">PretrainedModel</span><span class="p">,</span> <span class="n">build_config</span><span class="p">:</span> <span class="n">BuildConfig</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Engine</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;Build engine from given model and optimization options specified in the build_config</span>
<span class="sd"> WARNING: this function may change the given model object state in some optimization passes</span>
<span class="sd"> to avoid cloning a model since normally the LLM models consumes large memory.</span>
<span class="sd"> Create a new fresh model object if you need to build with different options.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">tic</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="c1"># avoid changing the input config</span>
<span class="n">build_config</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">build_config</span><span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">update_kv_cache_type</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span><span class="p">)</span>
<span class="n">_init_max_seq_len</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="p">,</span> <span class="n">build_config</span><span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">streamingllm</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;Paged Context FMHA is disabled because StreamingLLM is not supported when enabling paged KV context FMHA.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">reduce_fusion</span> <span class="ow">and</span> <span class="p">(</span>
<span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span>
<span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">!=</span> <span class="s2">&quot;LlamaForCausalLM&quot;</span>
<span class="ow">and</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">!=</span> <span class="s2">&quot;Gemma2ForCausalLM&quot;</span>
<span class="ow">and</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">!=</span> <span class="s2">&quot;MedusaForCausalLM&quot;</span><span class="p">)):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s1">&#39;Overriding reduce_fusion to False&#39;</span><span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">reduce_fusion</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">user_buffer</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">reduce_fusion</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s1">&#39;Overriding user_buffer to False&#39;</span><span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">user_buffer</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">norm_quant_fusion</span> <span class="ow">and</span> <span class="p">(</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">reduce_fusion</span>
<span class="ow">or</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">!=</span> <span class="s2">&quot;LlamaForCausalLM&quot;</span>
<span class="ow">or</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quantization</span><span class="o">.</span><span class="n">quant_algo</span> <span class="o">!=</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">NVFP4</span><span class="p">):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s1">&#39;Overriding norm_quant_fusion to False&#39;</span><span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">norm_quant_fusion</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quantization</span><span class="o">.</span><span class="n">quant_algo</span> <span class="o">==</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8</span> <span class="ow">or</span> \
<span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quantization</span><span class="o">.</span><span class="n">kv_cache_quant_algo</span> <span class="o">==</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">strongly_typed</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="p">,</span> <span class="s1">&#39;max_draft_len&#39;</span><span class="p">):</span>
<span class="c1"># If model.config has &#39;max_draft_len&#39; but build_config not specified,</span>
<span class="c1"># use the value of model.config.max_draft_len to set the value of build_config.max_draft_len</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">max_draft_len</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="p">,</span> <span class="s1">&#39;redrafter_num_beams&#39;</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">hasattr</span><span class="p">(</span>
<span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="p">,</span> <span class="s1">&#39;redrafter_draft_len_per_beam&#39;</span><span class="p">):</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">redrafter_num_beams</span> <span class="o">*</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">redrafter_draft_len_per_beam</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">!=</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">EXPLICIT_DRAFT_TOKENS</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s1">&#39;speculative_decoding_mode is not EXPLICIT_DRAFT_TOKENS for ReDrafter model. Overwriting speculative_decoding_mode&#39;</span>
<span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">=</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">EXPLICIT_DRAFT_TOKENS</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">!=</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">NONE</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;Increasing max_seq_len (</span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="si">}</span><span class="s1">) &#39;</span>
<span class="sa">f</span><span class="s1">&#39;by max_draft_len (</span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span><span class="si">}</span><span class="s1">) &#39;</span>
<span class="s1">&#39;to account for speculative decoding implementation specifics. &#39;</span>
<span class="s1">&#39;Maximum number of generated tokens remains the same. &#39;</span>
<span class="sa">f</span><span class="s1">&#39;New max_seq_len is set to </span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span> <span class="o">+=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">==</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">EAGLE</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="p">,</span> <span class="s1">&#39;num_eagle_layers&#39;</span><span class="p">)</span>
<span class="n">num_eagle_layers</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">num_eagle_layers</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;Increasing max_seq_len (</span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="si">}</span><span class="s1">) &#39;</span>
<span class="sa">f</span><span class="s1">&#39;by num_eagle_layers (</span><span class="si">{</span><span class="n">num_eagle_layers</span><span class="si">}</span><span class="s1">) &#39;</span>
<span class="s1">&#39;to account for EAGLE implementation specifics. &#39;</span>
<span class="s1">&#39;Maximum number of generated tokens remains the same. &#39;</span>
<span class="sa">f</span><span class="s1">&#39;New max_seq_len is set to </span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">num_eagle_layers</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span> <span class="o">+=</span> <span class="n">num_eagle_layers</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">!=</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">NONE</span><span class="p">:</span>
<span class="n">num_tokens</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_batch_size</span> <span class="o">*</span> <span class="p">(</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">+</span>
<span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_num_tokens</span> <span class="o">&lt;</span> <span class="n">num_tokens</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;max_num_tokens (</span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_num_tokens</span><span class="si">}</span><span class="s1">) is smaller than &#39;</span>
<span class="s1">&#39;max_batch_size * (max_draft_len + 1) = &#39;</span>
<span class="sa">f</span><span class="s1">&#39;(</span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="si">}</span><span class="s1"> * (</span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span><span class="si">}</span><span class="s1"> + 1)). &#39;</span>
<span class="sa">f</span><span class="s1">&#39;New max_num_tokens is set to </span><span class="si">{</span><span class="n">num_tokens</span><span class="si">}</span><span class="s1">.&#39;</span><span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_num_tokens</span> <span class="o">=</span> <span class="n">num_tokens</span>
<span class="c1"># Logics to control paged_context_fmha and fp8_context_fmha</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">context_fmha</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;Context FMHA is disabled, FP8 Context FMHA and Paged Context FMHA are disabled.&quot;</span>
<span class="p">)</span>
<span class="k">elif</span> <span class="ow">not</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quantization</span><span class="o">.</span><span class="n">quant_algo</span> <span class="ow">in</span> <span class="p">[</span>
<span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8</span><span class="p">,</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">W4A8_AWQ</span><span class="p">,</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">NVFP4</span>
<span class="p">]:</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;FP8 Context FMHA is disabled because it must be used together with the fp8 quantization workflow.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="ow">and</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_fp8_kv_cache</span><span class="p">(</span>
<span class="p">):</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;FP8 Paged Context FMHA is disabled because FP8 context FMHA is disabled.&quot;</span>
<span class="p">)</span>
<span class="k">elif</span> <span class="n">get_sm_version</span><span class="p">()</span> <span class="o">&lt;</span> <span class="mi">89</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;FP8 context FMHA is disabled because it is only supported on Ada and Hopper Arch.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="ow">and</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_fp8_kv_cache</span><span class="p">(</span>
<span class="p">):</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;FP8 Paged Context FMHA is disabled because FP8 context FMHA is disabled.&quot;</span>
<span class="p">)</span>
<span class="k">elif</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_fp8_kv_cache</span><span class="p">(</span>
<span class="p">)</span> <span class="ow">and</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;FP8 Paged Context FMHA is disabled because it must be used together with fp8 KV Cache.&quot;</span>
<span class="p">)</span>
<span class="k">elif</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_fp8_kv_cache</span><span class="p">(</span>
<span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;FP8 Context FMHA is enabled to support FP8 Paged Context FMHA.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="ow">and</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_int8_kv_cache</span><span class="p">(</span>
<span class="p">):</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;Paged Context FMHA is disabled because it doesn&#39;t work with int8 kv cache currently.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">get_sm_version</span><span class="p">()</span> <span class="o">&gt;=</span> <span class="mi">100</span> <span class="ow">and</span> <span class="n">get_sm_version</span><span class="p">()</span> <span class="o">&lt;</span> <span class="mi">120</span><span class="p">:</span>
<span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">is_int8_weight_only</span><span class="p">(</span>
<span class="p">)</span> <span class="ow">or</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">is_int4_weight_only</span><span class="p">(</span>
<span class="p">)</span> <span class="ow">or</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_int8_kv_cache</span><span class="p">():</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;INT8/INT4 quantization is not supported on SM&gt;=100.&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_act_and_weight_quant</span><span class="p">():</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s2">&quot;SmoothQuant is not supported on SM&gt;=100.&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_per_channel_scaling</span><span class="p">(</span>
<span class="p">)</span> <span class="ow">or</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_per_token_dynamic_scaling</span><span class="p">():</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;Per-channel or per-token scaling is not supported on SM&gt;=100.&quot;</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">optimize_model_with_config</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">build_config</span><span class="p">)</span>
<span class="n">builder</span> <span class="o">=</span> <span class="n">Builder</span><span class="p">()</span>
<span class="n">builder_config</span> <span class="o">=</span> <span class="n">builder</span><span class="o">.</span><span class="n">create_builder_config</span><span class="p">(</span>
<span class="n">precision</span><span class="o">=</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">use_refit</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">use_refit</span><span class="p">,</span>
<span class="n">timing_cache</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">input_timing_cache</span><span class="p">,</span>
<span class="n">int8</span><span class="o">=</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_act_or_weight_quant</span><span class="p">()</span>
<span class="ow">and</span> <span class="ow">not</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_per_group_scaling</span><span class="p">())</span>
<span class="ow">or</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_int8_kv_cache</span><span class="p">(),</span>
<span class="n">strongly_typed</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">,</span>
<span class="n">force_num_profiles</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">force_num_profiles</span><span class="p">,</span>
<span class="n">profiling_verbosity</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">profiling_verbosity</span><span class="p">,</span>
<span class="n">quant_mode</span><span class="o">=</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="p">,</span>
<span class="n">use_strip_plan</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">use_strip_plan</span><span class="p">,</span>
<span class="n">weight_sparsity</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">weight_sparsity</span><span class="p">,</span>
<span class="n">weight_streaming</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">weight_streaming</span><span class="p">,</span>
<span class="n">monitor_memory</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">monitor_memory</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">network</span> <span class="o">=</span> <span class="n">builder</span><span class="o">.</span><span class="n">create_network</span><span class="p">()</span>
<span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span>
<span class="n">use_auto_parallel</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">auto_parallel_config</span><span class="o">.</span><span class="n">enabled</span>
<span class="n">use_weight_only</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">is_weight_only</span><span class="p">()</span>
<span class="n">per_group</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_per_group_scaling</span><span class="p">()</span>
<span class="n">use_smooth_quant</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_act_and_weight_quant</span><span class="p">()</span>
<span class="n">use_qserve</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">is_qserve_w4a8</span><span class="p">()</span>
<span class="n">use_fp8_rowwise</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_fp8_rowwise</span><span class="p">()</span>
<span class="n">disable_weight_only_quant_plugin</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">disable_weight_only_quant_plugin</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span>
<span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="p">,</span> <span class="s1">&#39;disable_weight_only_quant_plugin&#39;</span><span class="p">)</span> <span class="k">else</span> <span class="kc">False</span>
<span class="n">use_fp8_rowwise</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_fp8_rowwise</span><span class="p">()</span>
<span class="n">use_fp4_gemm</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_nvfp4</span><span class="p">()</span>
<span class="k">if</span> <span class="n">use_fp4_gemm</span> <span class="ow">and</span> <span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">_explicitly_disable_gemm_plugin</span> <span class="ow">is</span> <span class="kc">False</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="s1">&#39;NVFP4 quantization detected, by default enabling NVFP4 GEMM plugin. To use OOTB GEMM, please explicitly set gemm_plugin to &quot;disable&quot;&#39;</span>
<span class="p">)</span>
<span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_plugin</span> <span class="o">=</span> <span class="s2">&quot;nvfp4&quot;</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">manage_weights</span><span class="p">:</span>
<span class="k">if</span> <span class="n">use_weight_only</span> <span class="ow">and</span> <span class="n">disable_weight_only_quant_plugin</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;Manage weights of weight only quant works only with plugin currently.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">use_weight_only</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">disable_weight_only_quant_plugin</span><span class="p">:</span>
<span class="k">if</span> <span class="n">per_group</span><span class="p">:</span>
<span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">weight_only_groupwise_quant_matmul_plugin</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">weight_only_quant_matmul_plugin</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span>
<span class="k">if</span> <span class="n">use_smooth_quant</span> <span class="ow">and</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quantization</span><span class="o">.</span><span class="n">_use_plugin_sq</span> <span class="ow">and</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">smooth_quant_plugins</span><span class="p">:</span>
<span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">set_smooth_quant_plugins</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_qserve</span><span class="p">:</span>
<span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">set_qserve_plugins</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_fp8_rowwise</span><span class="p">:</span>
<span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">set_fp8_rowwise_quant_plugins</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">nccl_plugin</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">world_size</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="k">else</span> <span class="kc">None</span>
<span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">set_nccl_plugin</span><span class="p">(</span><span class="n">nccl_plugin</span><span class="p">)</span>
<span class="k">with</span> <span class="n">net_guard</span><span class="p">(</span><span class="n">network</span><span class="p">):</span>
<span class="c1"># Prepare</span>
<span class="n">network</span><span class="o">.</span><span class="n">set_named_parameters</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">named_parameters</span><span class="p">())</span>
<span class="c1"># Forward</span>
<span class="n">prepare_input_args</span> <span class="o">=</span> <span class="p">{</span>
<span class="s2">&quot;max_batch_size&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">,</span>
<span class="s2">&quot;max_input_len&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_input_len</span><span class="p">,</span>
<span class="s2">&quot;max_seq_len&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="p">,</span>
<span class="s2">&quot;use_cache&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">!=</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">DISABLED</span><span class="p">,</span>
<span class="s2">&quot;max_beam_width&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_beam_width</span><span class="p">,</span>
<span class="s2">&quot;max_num_tokens&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_num_tokens</span><span class="p">,</span>
<span class="s2">&quot;opt_num_tokens&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">opt_num_tokens</span><span class="p">,</span>
<span class="s2">&quot;prompt_embedding_table_size&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_prompt_embedding_table_size</span><span class="p">,</span>
<span class="s2">&quot;max_draft_len&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span><span class="p">,</span>
<span class="s2">&quot;speculative_decoding_draft_tokens_external&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">==</span>
<span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">DRAFT_TOKENS_EXTERNAL</span><span class="p">,</span>
<span class="s2">&quot;gather_context_logits&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">,</span>
<span class="s2">&quot;lora_target_modules&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">lora_config</span><span class="o">.</span><span class="n">lora_target_modules</span>
<span class="p">}</span>
<span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">==</span> <span class="s2">&quot;DecoderModel&quot;</span> <span class="ow">or</span> <span class="s2">&quot;mllama&quot;</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span><span class="o">.</span><span class="n">lower</span><span class="p">(</span>
<span class="p">):</span>
<span class="n">prepare_input_args</span><span class="p">[</span><span class="s2">&quot;max_seq_len&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span>
<span class="n">prepare_input_args</span><span class="p">[</span>
<span class="s2">&quot;max_decoder_input_len&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_input_len</span>
<span class="n">prepare_input_args</span><span class="p">[</span>
<span class="s2">&quot;max_encoder_input_len&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_encoder_input_len</span>
<span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">==</span> <span class="s2">&quot;WhisperEncoder&quot;</span><span class="p">:</span>
<span class="n">prepare_input_args</span> <span class="o">=</span> <span class="p">{</span>
<span class="s2">&quot;max_batch_size&quot;</span><span class="p">:</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">,</span>
<span class="p">}</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">==</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">EAGLE</span><span class="p">:</span>
<span class="n">prepare_input_args</span><span class="p">[</span>
<span class="s2">&quot;spec_decoding_is_generation_length_variable&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">assert</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_batch_size</span> <span class="o">&lt;=</span> <span class="mi">512</span><span class="p">,</span> <span class="s2">&quot;Max batch size &gt; 512 is not supported for EAGLE&quot;</span>
<span class="k">assert</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">&lt;=</span> <span class="mi">256</span><span class="p">,</span> <span class="s2">&quot;Max draft len &gt; 256 is not supported for EAGLE&quot;</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">==</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">LOOKAHEAD_DECODING</span><span class="p">:</span>
<span class="n">prepare_input_args</span><span class="p">[</span>
<span class="s2">&quot;spec_decoding_is_generation_length_variable&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">==</span> <span class="s2">&quot;Qwen2VLForConditionalGeneration&quot;</span> <span class="ow">or</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">==</span> <span class="s2">&quot;Qwen2VLModel&quot;</span><span class="p">:</span>
<span class="n">prepare_input_args</span><span class="p">[</span>
<span class="s1">&#39;mrope_rotary_cos_sin_size&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">max_position_embeddings</span> <span class="o">*</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">rotary_embedding_dim</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">==</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">EAGLE</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;Paged Context FMHA is required for EAGLE. Turning it on&quot;</span><span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">prepare_inputs</span><span class="p">(</span><span class="o">**</span><span class="n">prepare_input_args</span><span class="p">)</span>
<span class="n">model</span><span class="p">(</span><span class="o">**</span><span class="n">inputs</span><span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">enable_debug_output</span><span class="p">:</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">named_network_outputs</span><span class="p">():</span>
<span class="n">network</span><span class="o">.</span><span class="n">_mark_output</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span>
<span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">!=</span> <span class="s2">&quot;DecoderModel&quot;</span><span class="p">:</span>
<span class="n">optimize</span><span class="p">(</span><span class="n">network</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_auto_parallel</span><span class="p">:</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">auto_parallel_config</span>
<span class="n">config</span><span class="o">.</span><span class="n">builder_flags</span> <span class="o">=</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">trt_builder_config</span><span class="o">.</span><span class="n">flags</span>
<span class="n">sharded_networks</span> <span class="o">=</span> <span class="n">auto_parallel</span><span class="p">(</span><span class="n">network</span><span class="p">,</span> <span class="n">config</span><span class="p">)</span>
<span class="n">network</span> <span class="o">=</span> <span class="n">sharded_networks</span><span class="p">[</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span><span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">build_config</span><span class="o">.</span><span class="n">auto_parallel_config</span><span class="o">.</span><span class="n">debug_mode</span><span class="p">:</span>
<span class="n">mapping</span> <span class="o">=</span> <span class="n">network</span><span class="o">.</span><span class="n">auto_parallel_config</span><span class="p">[</span><span class="s2">&quot;mapping&quot;</span><span class="p">]</span>
<span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">mapping</span> <span class="o">=</span> <span class="n">mapping</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">visualize_network</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">with</span> <span class="n">net_guard</span><span class="p">(</span><span class="n">network</span><span class="p">):</span>
<span class="n">network</span><span class="o">.</span><span class="n">to_onnx</span><span class="p">(</span><span class="n">build_config</span><span class="o">.</span><span class="n">visualize_network</span><span class="p">)</span>
<span class="c1"># Network -&gt; Engine</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Total time of constructing network from module object </span><span class="si">{</span><span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span><span class="o">-</span><span class="n">tic</span><span class="si">}</span><span class="s2"> seconds&quot;</span>
<span class="p">)</span>
<span class="n">managed_weights</span> <span class="o">=</span> <span class="p">{}</span> <span class="k">if</span> <span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">manage_weights</span> <span class="k">else</span> <span class="kc">None</span>
<span class="n">engine</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">dry_run</span> <span class="k">else</span> <span class="n">builder</span><span class="o">.</span><span class="n">build_engine</span><span class="p">(</span>
<span class="n">network</span><span class="p">,</span> <span class="n">builder_config</span><span class="p">,</span> <span class="n">managed_weights</span><span class="p">)</span>
<span class="n">engine_config</span> <span class="o">=</span> <span class="n">EngineConfig</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="p">,</span> <span class="n">build_config</span><span class="p">,</span> <span class="n">__version__</span><span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">output_timing_cache</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">ok</span> <span class="o">=</span> <span class="n">builder</span><span class="o">.</span><span class="n">save_timing_cache</span><span class="p">(</span><span class="n">builder_config</span><span class="p">,</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">output_timing_cache</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">ok</span><span class="p">,</span> <span class="s2">&quot;Failed to save timing cache.&quot;</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">psutil</span>
<span class="c1"># Get the current process</span>
<span class="n">current_process</span> <span class="o">=</span> <span class="n">psutil</span><span class="o">.</span><span class="n">Process</span><span class="p">()</span>
<span class="c1"># Get resource usage for the current process (self)</span>
<span class="n">rusage_s</span> <span class="o">=</span> <span class="n">current_process</span><span class="o">.</span><span class="n">memory_info</span><span class="p">()</span>
<span class="c1"># Get resource usage for all child processes</span>
<span class="n">children</span> <span class="o">=</span> <span class="n">current_process</span><span class="o">.</span><span class="n">children</span><span class="p">(</span><span class="n">recursive</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">rusage_c</span> <span class="o">=</span> <span class="p">[</span><span class="n">child</span><span class="o">.</span><span class="n">memory_info</span><span class="p">()</span> <span class="k">for</span> <span class="n">child</span> <span class="ow">in</span> <span class="n">children</span><span class="p">]</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Build phase peak memory: </span><span class="si">{</span><span class="n">rusage_s</span><span class="o">.</span><span class="n">rss</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="mi">1024</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="mi">1024</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2"> MB, children: </span><span class="si">{</span><span class="nb">sum</span><span class="p">([</span><span class="n">ru</span><span class="o">.</span><span class="n">rss</span><span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="n">ru</span><span class="w"> </span><span class="ow">in</span><span class="w"> </span><span class="n">rusage_c</span><span class="p">])</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="mi">1024</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="mi">1024</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2"> MB&quot;</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">Engine</span><span class="p">(</span><span class="n">engine_config</span><span class="p">,</span> <span class="n">engine</span><span class="p">,</span> <span class="n">managed_weights</span><span class="p">)</span>
</pre></div>
</article>
<footer class="prev-next-footer d-print-none">
<div class="prev-next-area">
</div>
</footer>
</div>
<div class="bd-sidebar-secondary"></div>
</div>
<footer class="bd-footer-content">
</footer>
</main>
</div>
</div>
<!-- Scripts loaded after <body> so the DOM is not blocked -->
<script defer src="../../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
<script defer src="../../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
<footer class="bd-footer">
<div class="bd-footer__inner bd-page-width">
<div class="footer-items__start">
<div class="footer-item">
<a class="footer-brand logo" href="https://www.nvidia.com">
<img src="../../_static/nvidia-logo-horiz-rgb-1c-blk-for-screen.svg" class="logo__image only-light" alt="NVIDIA"/>
<img src="../../_static/nvidia-logo-horiz-rgb-1c-wht-for-screen.svg" class="logo__image only-dark" alt="NVIDIA"/>
</a></div>
<div class="footer-item">
<div class="footer-links">
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/privacy-policy/">Privacy Policy</a>
|
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/privacy-center/">Manage My Privacy</a>
|
<a class="external" href="https://www.nvidia.com/en-us/preferences/start/">Do Not Sell or Share My Data</a>
|
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/terms-of-service/">Terms of Service</a>
|
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/accessibility/">Accessibility</a>
|
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/company-policies/">Corporate Policies</a>
|
<a class="external" href="https://www.nvidia.com/en-us/product-security/">Product Security</a>
|
<a class="external" href="https://www.nvidia.com/en-us/contact/">Contact</a>
</div>
</div>
<div class="footer-item">
<p class="copyright">
Copyright © 2025, NVIDIA.
<br/>
</p>
</div>
<div class="footer-item">
<div class="extra_footer">
<p>Last updated on September 02, 2025.</p>
<p>This page is generated by TensorRT-LLM commit <a href="https://github.com/NVIDIA/TensorRT-LLM/tree/e81c50d">e81c50d</a>.</p>
</div></div>
</div>
</div>
</footer>
</body>
</html>