TensorRT-LLMs/_modules/tensorrt_llm/builder.html
2026-01-08 05:44:03 +00:00

1942 lines
253 KiB
HTML
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!DOCTYPE html>
<html lang="en" data-content_root="../../" >
<head>
<meta charset="utf-8" />
<!-- Single viewport declaration for the page (a second, equivalent one further down was redundant). -->
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>tensorrt_llm.builder &#8212; TensorRT LLM</title>
<!-- Copy the color mode and theme saved in localStorage onto the root element's data attributes. -->
<script data-cfasync="false">
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
</script>
<!--
This gives us a CSS class (pst-js-only) whose elements are hidden only when JavaScript is disabled.
-->
<noscript>
<style>
.pst-js-only { display: none !important; }
</style>
</noscript>
<!-- Loaded before other Sphinx assets -->
<link href="../../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link href="../../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
<link rel="stylesheet" type="text/css" href="../../_static/pygments.css?v=8f2a1f02" />
<link rel="stylesheet" type="text/css" href="../../_static/styles/nvidia-sphinx-theme.css?v=933278ad" />
<link rel="stylesheet" type="text/css" href="../../_static/copybutton.css?v=76b2166b" />
<link rel="stylesheet" type="text/css" href="../../_static/autodoc_pydantic.css" />
<link rel="stylesheet" type="text/css" href="../../_static/togglebutton.css?v=13237357" />
<link rel="stylesheet" type="text/css" href="../../_static/config_selector.css?v=e17d8078" />
<link rel="stylesheet" type="text/css" href="../../_static/custom.css?v=19d20f17" />
<!-- So that users can add custom icons -->
<script src="../../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
<!-- Pre-loaded scripts that we'll load fully later -->
<link rel="preload" as="script" href="../../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
<link rel="preload" as="script" href="../../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
<script src="../../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../../_static/doctools.js?v=9a2dae69"></script>
<script src="../../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../../_static/copybutton.js?v=65e89d2a"></script>
<script src="../../_static/config_selector.js?v=aaf6cd4a"></script>
<!-- Globals defined ahead of togglebutton.js: hint strings and the open-on-print flag. -->
<script>let toggleHintShow = 'Click to show';</script>
<script>let toggleHintHide = 'Click to hide';</script>
<script>let toggleOpenOnPrint = 'true';</script>
<script src="../../_static/togglebutton.js?v=4a39c7ea"></script>
<!-- Declared once; the generated page repeated this identical script twice. -->
<script>var togglebuttonSelector = '.toggle, .admonition.dropdown';</script>
<!-- Page-specific values added to the DOCUMENTATION_OPTIONS object. -->
<script>DOCUMENTATION_OPTIONS.pagename = '_modules/tensorrt_llm/builder';</script>
<script>
DOCUMENTATION_OPTIONS.theme_version = '0.16.1';
DOCUMENTATION_OPTIONS.theme_switcher_json_url = './_static/switcher.json';
DOCUMENTATION_OPTIONS.theme_switcher_version_match = '1.2.0rc7';
DOCUMENTATION_OPTIONS.show_version_warning_banner = false;
</script>
<link rel="icon" href="../../_static/favicon.png"/>
<link rel="index" title="Index" href="../../genindex.html" />
<link rel="search" title="Search" href="../../search.html" />
<meta name="docsearch:language" content="en"/>
<meta name="docsearch:version" content="1.2.0rc7" />
</head>
<body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode="">
<div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div>
<div id="pst-scroll-pixel-helper"></div>
<button type="button" class="btn rounded-pill" id="pst-back-to-top">
<i class="fa-solid fa-arrow-up"></i>Back to top</button>
<!-- Search dialog: a GET form that sends the query as parameter "q" to ../../search.html. -->
<dialog id="pst-search-dialog">
<form class="bd-search d-flex align-items-center"
action="../../search.html"
method="get">
<i class="fa-solid fa-magnifying-glass"></i>
<!-- No visible label; the accessible name comes from aria-label. Browser autocomplete, autocorrect, autocapitalize and spellcheck are all switched off for this field. -->
<input type="search"
class="form-control"
name="q"
placeholder="Search the docs ..."
aria-label="Search the docs ..."
autocomplete="off"
autocorrect="off"
autocapitalize="off"
spellcheck="false"/>
<!-- Visual hint for the Ctrl+K keyboard shortcut. -->
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span>
</form>
</dialog>
<div class="pst-async-banner-revealer d-none">
<aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside>
</div>
<!-- Top navigation bar: sidebar toggle, brand logo, version switcher, search button and color-mode button. -->
<header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none">
<div class="bd-header__inner bd-page-width">
<!-- Icon-only button; its accessible name comes from aria-label. -->
<button class="pst-navbar-icon sidebar-toggle primary-toggle" aria-label="Site navigation">
<span class="fa-solid fa-bars"></span>
</button>
<div class="col-lg-3 navbar-header-items__start">
<div class="navbar-item">
<a class="navbar-brand logo" href="../../index.html">
<!-- Two logo variants (only-light / only-dark); the dark variant is additionally marked pst-js-only. -->
<img src="../../_static/nvidia-logo-horiz-rgb-blk-for-screen.svg" class="logo__image only-light" alt="TensorRT LLM - Home"/>
<img src="../../_static/nvidia-logo-horiz-rgb-wht-for-screen.svg" class="logo__image only-dark pst-js-only" alt="TensorRT LLM - Home"/>
<p class="title logo__title">TensorRT LLM</p>
</a></div>
</div>
<div class="col-lg-9 navbar-header-items">
<div class="me-auto navbar-header-items__center">
<div class="navbar-item">
<!-- Version switcher: the button and its listbox reference each other via aria-controls / aria-labelledby. -->
<div class="version-switcher__container dropdown pst-js-only">
<button id="pst-version-switcher-button-2"
type="button"
class="version-switcher__button btn btn-sm dropdown-toggle"
data-bs-toggle="dropdown"
aria-haspopup="listbox"
aria-controls="pst-version-switcher-list-2"
aria-label="Version switcher list"
>
Choose version <!-- this text may get changed later by javascript -->
<span class="caret"></span>
</button>
<div id="pst-version-switcher-list-2"
class="version-switcher__menu dropdown-menu list-group-flush py-0"
role="listbox" aria-labelledby="pst-version-switcher-button-2">
<!-- dropdown will be populated by javascript on page load -->
</div>
</div></div>
</div>
<div class="navbar-header-items__end">
<div class="navbar-item navbar-persistent--container">
<!-- Search button with the Ctrl+K shortcut hint. -->
<button class="btn search-button-field search-button__button pst-js-only" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="fa-solid fa-magnifying-glass"></i>
<span class="search-button__default-text">Search</span>
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
</button>
</div>
<div class="navbar-item">
<!-- Color-mode button: one icon per data-mode value (light, dark, auto). -->
<button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button pst-js-only" aria-label="Color mode" data-bs-title="Color mode" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light" title="Light"></i>
<i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark" title="Dark"></i>
<i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto" title="System Settings"></i>
</button></div>
</div>
</div>
<!-- Second copy of the search button, placed in the navbar-persistent--mobile container. -->
<div class="navbar-persistent--mobile">
<button class="btn search-button-field search-button__button pst-js-only" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="fa-solid fa-magnifying-glass"></i>
<span class="search-button__default-text">Search</span>
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
</button>
</div>
</div>
</header>
<div class="bd-container">
<div class="bd-container__inner bd-page-width">
<dialog id="pst-primary-sidebar-modal"></dialog>
<div id="pst-primary-sidebar" class="bd-sidebar-primary bd-sidebar">
<a class="navbar-brand logo" href="../../index.html">
<img src="../../_static/nvidia-logo-horiz-rgb-blk-for-screen.svg" class="logo__image only-light" alt="TensorRT LLM - Home"/>
<img src="../../_static/nvidia-logo-horiz-rgb-wht-for-screen.svg" class="logo__image only-dark pst-js-only" alt="TensorRT LLM - Home"/>
<p class="title logo__title">TensorRT LLM</p>
</a>
<div class="sidebar-header-items sidebar-primary__section">
<div class="sidebar-header-items__center">
<div class="navbar-item">
<div class="version-switcher__container dropdown pst-js-only">
<button id="pst-version-switcher-button-3"
type="button"
class="version-switcher__button btn btn-sm dropdown-toggle"
data-bs-toggle="dropdown"
aria-haspopup="listbox"
aria-controls="pst-version-switcher-list-3"
aria-label="Version switcher list"
>
Choose version <!-- this text may get changed later by javascript -->
<span class="caret"></span>
</button>
<div id="pst-version-switcher-list-3"
class="version-switcher__menu dropdown-menu list-group-flush py-0"
role="listbox" aria-labelledby="pst-version-switcher-button-3">
<!-- dropdown will be populated by javascript on page load -->
</div>
</div></div>
</div>
<div class="sidebar-header-items__end">
<div class="navbar-item">
<button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button pst-js-only" aria-label="Color mode" data-bs-title="Color mode" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light" title="Light"></i>
<i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark" title="Dark"></i>
<i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto" title="System Settings"></i>
</button></div>
</div>
</div>
<div class="sidebar-primary-items__start sidebar-primary__section">
<div class="sidebar-primary-item">
<nav class="bd-docs-nav bd-links"
aria-label="Table of Contents">
<p class="bd-links__title" role="heading" aria-level="1">Table of Contents</p>
<div class="bd-toc-item navbar-nav"><p aria-level="2" class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../overview.html">Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../quick-start-guide.html">Quick Start Guide</a></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../../installation/index.html">Installation</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../../installation/containers.html">Pre-built release container images on NGC</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../installation/linux.html">Installing on Linux via <code class="docutils literal notranslate"><span class="pre">pip</span></code></a></li>
<li class="toctree-l2"><a class="reference internal" href="../../installation/build-from-source-linux.html">Building from Source Code on Linux</a></li>
</ul>
</details></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Deployment Guide</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1 has-children"><a class="reference internal" href="../../examples/llm_api_examples.html">LLM Examples</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_inference.html">Generate text</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_inference_async.html">Generate text asynchronously</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_inference_async_streaming.html">Generate text in streaming</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_inference_distributed.html">Distributed LLM Generation</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_guided_decoding.html">Generate text with guided decoding</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_logits_processor.html">Control generated text using logits processor</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_multilora.html">Generate text with multiple LoRA adapters</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_sparse_attention.html">Sparse Attention</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_speculative_decoding.html">Speculative Decoding</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_kv_cache_connector.html">KV Cache Connector</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_kv_cache_offloading.html">KV Cache Offloading</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_runtime.html">Runtime Configuration Examples</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_sampling.html">Sampling Techniques Showcase</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_mgmn_llm_distributed.html">Run LLM-API with pytorch backend on Slurm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_mgmn_trtllm_bench.html">Run trtllm-bench with pytorch backend on Slurm</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/llm_mgmn_trtllm_serve.html">Run trtllm-serve with pytorch backend on Slurm</a></li>
</ul>
</details></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../../examples/trtllm_serve_examples.html">Online Serving Examples</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../../examples/aiperf_client.html">Aiperf Client</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/aiperf_client_for_multimodal.html">Aiperf Client For Multimodal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/curl_chat_client.html">Curl Chat Client</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/curl_chat_client_for_multimodal.html">Curl Chat Client For Multimodal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/curl_completion_client.html">Curl Completion Client</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/curl_responses_client.html">Curl Responses Client</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/deepseek_r1_reasoning_parser.html">Deepseek R1 Reasoning Parser</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/openai_chat_client.html">OpenAI Chat Client</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/openai_chat_client_for_multimodal.html">OpenAI Chat Client for Multimodal</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/openai_completion_client.html">OpenAI Completion Client</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/openai_completion_client_for_lora.html">OpenAI Completion Client for LoRA</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/openai_completion_client_json_schema.html">OpenAI Completion Client with JSON Schema</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../examples/openai_responses_client.html">OpenAI Responses Client</a></li>
</ul>
</details></li>
<li class="toctree-l1"><a class="reference internal" href="../../examples/dynamo_k8s_example.html">Dynamo K8s Example</a></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../../deployment-guide/index.html">Model Recipes</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../../deployment-guide/deployment-guide-for-deepseek-r1-on-trtllm.html">Deployment Guide for DeepSeek R1 on TensorRT LLM - Blackwell &amp; Hopper Hardware</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../deployment-guide/deployment-guide-for-llama3.3-70b-on-trtllm.html">Deployment Guide for Llama3.3 70B on TensorRT LLM - Blackwell &amp; Hopper Hardware</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../deployment-guide/deployment-guide-for-llama4-scout-on-trtllm.html">Deployment Guide for Llama4 Scout 17B on TensorRT LLM - Blackwell &amp; Hopper Hardware</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../deployment-guide/deployment-guide-for-gpt-oss-on-trtllm.html">Deployment Guide for GPT-OSS on TensorRT LLM - Blackwell Hardware</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../deployment-guide/deployment-guide-for-qwen3-on-trtllm.html">Deployment Guide for Qwen3 on TensorRT LLM - Blackwell &amp; Hopper Hardware</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../deployment-guide/deployment-guide-for-qwen3-next-on-trtllm.html">Deployment Guide for Qwen3 Next on TensorRT LLM - Blackwell &amp; Hopper Hardware</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../deployment-guide/deployment-guide-for-kimi-k2-thinking-on-trtllm.html">Deployment Guide for Kimi K2 Thinking on TensorRT LLM - Blackwell</a></li>
</ul>
</details></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Models</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../models/supported-models.html">Supported Models</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../models/adding-new-model.html">Adding a New Model</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">CLI Reference</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../commands/trtllm-bench.html">trtllm-bench</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../commands/trtllm-eval.html">trtllm-eval</a></li>
<li class="toctree-l1 has-children"><a class="reference internal" href="../../commands/trtllm-serve/index.html">trtllm-serve</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
<li class="toctree-l2"><a class="reference internal" href="../../commands/trtllm-serve/trtllm-serve.html">trtllm-serve</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../commands/trtllm-serve/run-benchmark-with-trtllm-serve.html">Run benchmarking with <code class="docutils literal notranslate"><span class="pre">trtllm-serve</span></code></a></li>
</ul>
</details></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">API Reference</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../llm-api/index.html">LLM API Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../llm-api/reference.html">API Reference</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Features</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../features/feature-combination-matrix.html">Feature Combination Matrix</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/attention.html">Multi-Head, Multi-Query, and Group-Query Attention</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/disagg-serving.html">Disaggregated Serving</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/kvcache.html">KV Cache System</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/long-sequence.html">Long Sequences</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/lora.html">LoRA (Low-Rank Adaptation)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/multi-modality.html">Multimodal Support in TensorRT LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/overlap-scheduler.html">Overlap Scheduler</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/paged-attention-ifb-scheduler.html">Paged Attention, IFB, and Request Scheduling</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/parallel-strategy.html">Parallelism in TensorRT LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/quantization.html">Quantization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/sampling.html">Sampling</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/additional-outputs.html">Additional Outputs</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/guided-decoding.html">Guided Decoding</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/speculative-decoding.html">Speculative Decoding</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/checkpoint-loading.html">Checkpoint Loading</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/auto_deploy/auto-deploy.html">AutoDeploy (Beta)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/ray-orchestrator.html">Ray Orchestrator (Prototype)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/torch_compile_and_piecewise_cuda_graph.html">Torch Compile &amp; Piecewise CUDA Graph</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/helix.html">Helix Parallelism</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../features/kv-cache-connector.html">KV Cache Connector</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Developer Guide</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../developer-guide/overview.html">Architecture Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../developer-guide/perf-analysis.html">Performance Analysis</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../developer-guide/perf-benchmarking.html">TensorRT LLM Benchmarking</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../developer-guide/ci-overview.html">Continuous Integration Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../developer-guide/dev-containers.html">Using Dev Containers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../developer-guide/api-change.html">LLM API Change Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../developer-guide/kv-transfer.html">Introduction to KV Cache Transmission</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Blogs</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog10_ADP_Balance_Strategy.html">ADP Balance Strategy</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog11_GPT_OSS_Eagle3.html">Running GPT-OSS-120B with Eagle3 Speculative Decoding on GB200/B200 (TensorRT LLM)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog12_Combining_Guided_Decoding_and_Speculative_Decoding.html">Combining Guided Decoding and Speculative Decoding: Making CPU and GPU Cooperate Seamlessly</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog13_Inference_Time_Compute_Implementation_in_TensorRT-LLM.html">Inference Time Compute Implementation in TensorRT LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog14_Scaling_Expert_Parallelism_in_TensorRT-LLM_part3.html">Scaling Expert Parallelism in TensorRT LLM (Part 3: Pushing the Performance Boundary)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog1_Pushing_Latency_Boundaries_Optimizing_DeepSeek-R1_Performance_on_NVIDIA_B200_GPUs.html">Pushing Latency Boundaries: Optimizing DeepSeek-R1 Performance on NVIDIA B200 GPUs</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog2_DeepSeek_R1_MTP_Implementation_and_Optimization.html">DeepSeek R1 MTP Implementation and Optimization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog3_Optimizing_DeepSeek_R1_Throughput_on_NVIDIA_Blackwell_GPUs.html">Optimizing DeepSeek R1 Throughput on NVIDIA Blackwell GPUs: A Deep Dive for Developers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog4_Scaling_Expert_Parallelism_in_TensorRT-LLM.html">Scaling Expert Parallelism in TensorRT LLM (Part 1: Design and Implementation of Large-scale EP)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog5_Disaggregated_Serving_in_TensorRT-LLM.html">Disaggregated Serving in TensorRT LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog6_Llama4_maverick_eagle_guide.html">How to launch Llama4 Maverick + Eagle3 TensorRT LLM server</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog7_NGram_performance_Analysis_And_Auto_Enablement.html">N-Gram Speculative Decoding in TensorRT LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog8_Scaling_Expert_Parallelism_in_TensorRT-LLM_part2.html">Scaling Expert Parallelism in TensorRT LLM (Part 2: Performance Status and Optimization)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.html">Running a High Performance GPT-OSS-120B Inference Server with TensorRT LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.html">How to get best performance on DeepSeek-R1 in TensorRT LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/XQA-kernel.html">New XQA-kernel provides 2.4x more Llama-70B throughput within the same latency budget</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Quick Links</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/releases">Releases</a></li>
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM">GitHub Code</a></li>
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap">Roadmap</a></li>
</ul>
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Use TensorRT Engine</span></p>
<ul class="nav bd-sidenav">
<li class="toctree-l1"><a class="reference internal" href="../../legacy/tensorrt_quickstart.html">LLM API with TensorRT Engine</a></li>
</ul>
</div>
</nav></div>
</div>
<div class="sidebar-primary-items__end sidebar-primary__section">
</div>
</div>
<main id="main-content" class="bd-main" role="main">
<div class="bd-content">
<div class="bd-article-container">
<div class="bd-header-article d-print-none">
<div class="header-article-items header-article__inner">
<div class="header-article-items__start">
<div class="header-article-item">
<nav aria-label="Breadcrumb" class="d-print-none">
<ul class="bd-breadcrumbs">
<li class="breadcrumb-item breadcrumb-home">
<a href="../../index.html" class="nav-link" aria-label="Home">
<i class="fa-solid fa-home"></i>
</a>
</li>
<li class="breadcrumb-item"><a href="../index.html" class="nav-link">Module code</a></li>
<li class="breadcrumb-item active" aria-current="page"><span class="ellipsis">tensorrt_llm.builder</span></li>
</ul>
</nav>
</div>
</div>
</div>
</div>
<div id="searchbox"></div>
<article class="bd-article">
<h1>Source code for tensorrt_llm.builder</h1><div class="highlight"><pre>
<span></span><span class="c1"># SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &amp; AFFILIATES. All rights reserved.</span>
<span class="c1"># SPDX-License-Identifier: Apache-2.0</span>
<span class="c1">#</span>
<span class="c1"># Licensed under the Apache License, Version 2.0 (the &quot;License&quot;);</span>
<span class="c1"># you may not use this file except in compliance with the License.</span>
<span class="c1"># You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">json</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">math</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">os</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">shutil</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">time</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pathlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">Path</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">typing</span><span class="w"> </span><span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">tensorrt</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">trt</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">pydantic</span><span class="w"> </span><span class="kn">import</span> <span class="n">BaseModel</span><span class="p">,</span> <span class="n">Field</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">._common</span><span class="w"> </span><span class="kn">import</span> <span class="n">_is_building</span><span class="p">,</span> <span class="n">check_max_num_tokens</span><span class="p">,</span> <span class="n">serialize_engine</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">._utils</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span><span class="n">get_sm_version</span><span class="p">,</span> <span class="n">np_bfloat16</span><span class="p">,</span> <span class="n">np_float8</span><span class="p">,</span> <span class="n">str_dtype_to_trt</span><span class="p">,</span>
<span class="n">to_json_file</span><span class="p">,</span> <span class="n">trt_gte</span><span class="p">)</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.functional</span><span class="w"> </span><span class="kn">import</span> <span class="n">PositionEmbeddingType</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.graph_rewriting</span><span class="w"> </span><span class="kn">import</span> <span class="n">optimize</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.llmapi.kv_cache_type</span><span class="w"> </span><span class="kn">import</span> <span class="n">KVCacheType</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.logger</span><span class="w"> </span><span class="kn">import</span> <span class="n">logger</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.lora_helper</span><span class="w"> </span><span class="kn">import</span> <span class="n">LoraConfig</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.models</span><span class="w"> </span><span class="kn">import</span> <span class="n">PretrainedConfig</span><span class="p">,</span> <span class="n">PretrainedModel</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.models.modeling_utils</span><span class="w"> </span><span class="kn">import</span> <span class="n">SpeculativeDecodingMode</span><span class="p">,</span> <span class="n">optimize_model</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.network</span><span class="w"> </span><span class="kn">import</span> <span class="n">Network</span><span class="p">,</span> <span class="n">net_guard</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.plugin</span><span class="w"> </span><span class="kn">import</span> <span class="n">PluginConfig</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.quantization</span><span class="w"> </span><span class="kn">import</span> <span class="n">QuantAlgo</span><span class="p">,</span> <span class="n">QuantMode</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">.version</span><span class="w"> </span><span class="kn">import</span> <span class="n">__version__</span>
<span class="k">class</span><span class="w"> </span><span class="nc">ConfigEncoder</span><span class="p">(</span><span class="n">json</span><span class="o">.</span><span class="n">JSONEncoder</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="nf">default</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">obj</span><span class="p">):</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="s1">&#39;model_dump&#39;</span><span class="p">):</span>
<span class="c1"># Handle Pydantic models (including DecodingBaseConfig and subclasses)</span>
<span class="k">return</span> <span class="n">obj</span><span class="o">.</span><span class="n">model_dump</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s1">&#39;json&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">default</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span>
<span class="k">class</span><span class="w"> </span><span class="nc">BuilderConfig</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="c1"># intentionally use **kwargs, user should never call this ctor directly,</span>
<span class="c1"># use Builder.create_builder_config() instead</span>
<span class="k">pass</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_init</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">trt_builder_config</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_trt_builder_config</span> <span class="o">=</span> <span class="n">trt_builder_config</span>
<span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">trt_builder_config</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">trt</span><span class="o">.</span><span class="n">IBuilderConfig</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_trt_builder_config</span>
<span class="k">def</span><span class="w"> </span><span class="nf">to_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;return a dict with keys</span>
<span class="sd"> {</span>
<span class="sd"> &quot;builder_config&quot;: {</span>
<span class="sd"> # all key values set by the _init function</span>
<span class="sd"> },</span>
<span class="sd"> &quot;plugin_config&quot;: {</span>
<span class="sd"> # the network plugin_config (if any) attached to this BuilderConfig object</span>
<span class="sd"> # inside the Builder.build_engine</span>
<span class="sd"> }</span>
<span class="sd"> }</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">config</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;builder_config&#39;</span><span class="p">:</span> <span class="p">{}}</span>
<span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="k">if</span> <span class="n">k</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;_trt_builder_config&#39;</span><span class="p">,</span> <span class="s1">&#39;plugin_config&#39;</span><span class="p">]:</span>
<span class="n">config</span><span class="p">[</span><span class="s1">&#39;builder_config&#39;</span><span class="p">][</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="fm">__getattribute__</span><span class="p">(</span><span class="n">k</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s1">&#39;plugin_config&#39;</span><span class="p">):</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">,</span> <span class="n">PluginConfig</span><span class="p">),</span> \
<span class="sa">f</span><span class="s2">&quot;Found unexpected plugin_config object with type: </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">config</span><span class="p">[</span><span class="s1">&#39;plugin_config&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">model_dump</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">&quot;json&quot;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">config</span>
<span class="k">class</span><span class="w"> </span><span class="nc">Builder</span><span class="p">():</span>
<span class="n">_ALLOWED_PRECISIONS</span> <span class="o">=</span> <span class="p">[</span>
<span class="s1">&#39;float32&#39;</span><span class="p">,</span> <span class="s1">&#39;float16&#39;</span><span class="p">,</span> <span class="s1">&#39;bfloat16&#39;</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">HALF</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">FLOAT</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">BF16</span>
<span class="p">]</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_trt_builder</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Builder</span><span class="p">(</span><span class="n">logger</span><span class="o">.</span><span class="n">trt_logger</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">strongly_typed</span> <span class="o">=</span> <span class="kc">True</span>
<span class="nd">@property</span>
<span class="k">def</span><span class="w"> </span><span class="nf">trt_builder</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">trt</span><span class="o">.</span><span class="n">Builder</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_trt_builder</span>
<span class="k">def</span><span class="w"> </span><span class="nf">create_network</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Network</span><span class="p">:</span>
<span class="n">explicit_batch_flag</span> <span class="o">=</span> <span class="mi">0</span>
<span class="c1"># Explicit batch flag will be deprecated in TRT 10</span>
<span class="k">if</span> <span class="s2">&quot;EXPLICIT_BATCH&quot;</span> <span class="ow">in</span> <span class="n">trt</span><span class="o">.</span><span class="n">NetworkDefinitionCreationFlag</span><span class="o">.</span><span class="n">__members__</span><span class="o">.</span><span class="n">keys</span><span class="p">(</span>
<span class="p">):</span>
<span class="n">explicit_batch_flag</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">&lt;&lt;</span> <span class="nb">int</span><span class="p">(</span>
<span class="n">trt</span><span class="o">.</span><span class="n">NetworkDefinitionCreationFlag</span><span class="o">.</span><span class="n">EXPLICIT_BATCH</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
<span class="k">return</span> <span class="n">Network</span><span class="p">()</span><span class="o">.</span><span class="n">_init</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trt_builder</span><span class="o">.</span><span class="n">create_network</span><span class="p">(</span>
<span class="n">explicit_batch_flag</span>
<span class="o">|</span> <span class="p">(</span><span class="mi">1</span> <span class="o">&lt;&lt;</span> <span class="nb">int</span><span class="p">(</span>
<span class="n">trt</span><span class="o">.</span><span class="n">NetworkDefinitionCreationFlag</span><span class="o">.</span><span class="n">STRONGLY_TYPED</span><span class="p">))))</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">Network</span><span class="p">()</span><span class="o">.</span><span class="n">_init</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">trt_builder</span><span class="o">.</span><span class="n">create_network</span><span class="p">(</span><span class="n">explicit_batch_flag</span><span class="p">))</span>
<span class="k">def</span><span class="w"> </span><span class="nf">create_builder_config</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">precision</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="p">],</span>
<span class="n">timing_cache</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">,</span>
<span class="n">trt</span><span class="o">.</span><span class="n">ITimingCache</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">tensor_parallel</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">use_refit</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">int8</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">strongly_typed</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">force_num_profiles</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">profiling_verbosity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;layer_names_only&quot;</span><span class="p">,</span>
<span class="n">use_strip_plan</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">weight_streaming</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">precision_constraints</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&quot;obey&quot;</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">BuilderConfig</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39; @brief Create a builder config with given precisions and timing cache</span>
<span class="sd"> @param precision: one of allowed precisions, defined in Builder._ALLOWED_PRECISIONS</span>
<span class="sd"> @param timing_cache: a timing cache object or a path to a timing cache file</span>
<span class="sd"> @param tensor_parallel: number of GPUs used for tensor parallel</span>
<span class="sd"> @param kwargs: any other arguments users would like to attach to the config object as attributes</span>
<span class="sd"> @param use_refit: set to accelerate multi-gpu building, build engine for 1 gpu and refit for the others</span>
<span class="sd"> @param int8: whether to build with int8 enabled or not. Can&#39;t be used together with the use_refit option</span>
<span class="sd"> @return: A BuilderConfig object, return None if failed</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">strongly_typed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">strongly_typed</span> <span class="ow">and</span> <span class="n">strongly_typed</span>
<span class="n">quant_mode</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;quant_mode&quot;</span><span class="p">,</span> <span class="n">QuantMode</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">strongly_typed</span> <span class="ow">and</span> <span class="n">precision</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ALLOWED_PRECISIONS</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;precision should be one of </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_ALLOWED_PRECISIONS</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">config</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_builder</span><span class="o">.</span><span class="n">create_builder_config</span><span class="p">()</span>
<span class="k">if</span> <span class="n">weight_streaming</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">WEIGHT_STREAMING</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">:</span>
<span class="n">fp8</span> <span class="o">=</span> <span class="n">quant_mode</span><span class="o">.</span><span class="n">has_fp8_qdq</span><span class="p">()</span> <span class="ow">or</span> <span class="n">quant_mode</span><span class="o">.</span><span class="n">has_fp8_kv_cache</span><span class="p">()</span>
<span class="k">if</span> <span class="n">precision</span> <span class="o">==</span> <span class="s1">&#39;float16&#39;</span> <span class="ow">or</span> <span class="n">precision</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">HALF</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">FP16</span><span class="p">)</span>
<span class="k">if</span> <span class="n">precision_constraints</span> <span class="o">==</span> <span class="s1">&#39;obey&#39;</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">OBEY_PRECISION_CONSTRAINTS</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">precision</span> <span class="o">==</span> <span class="s1">&#39;bfloat16&#39;</span> <span class="ow">or</span> <span class="n">precision</span> <span class="o">==</span> <span class="n">trt</span><span class="o">.</span><span class="n">DataType</span><span class="o">.</span><span class="n">BF16</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">BF16</span><span class="p">)</span>
<span class="k">if</span> <span class="n">precision_constraints</span> <span class="o">==</span> <span class="s1">&#39;obey&#39;</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">OBEY_PRECISION_CONSTRAINTS</span><span class="p">)</span>
<span class="k">if</span> <span class="n">int8</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">INT8</span><span class="p">)</span>
<span class="k">if</span> <span class="n">fp8</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">FP8</span><span class="p">)</span>
<span class="k">if</span> <span class="n">precision_constraints</span> <span class="o">==</span> <span class="s1">&#39;obey&#39;</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">OBEY_PRECISION_CONSTRAINTS</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_refit</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">REFIT</span><span class="p">)</span>
<span class="c1"># Use fine-grained refit when strip plan is enabled in TRT10.2+.</span>
<span class="k">if</span> <span class="n">use_strip_plan</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">REFIT_INDIVIDUAL</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_strip_plan</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">STRIP_PLAN</span><span class="p">)</span>
<span class="c1"># Set TRT Engine profiling verbosity</span>
<span class="k">if</span> <span class="n">profiling_verbosity</span> <span class="o">==</span> <span class="s2">&quot;detailed&quot;</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">profiling_verbosity</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ProfilingVerbosity</span><span class="o">.</span><span class="n">DETAILED</span>
<span class="k">elif</span> <span class="n">profiling_verbosity</span> <span class="o">==</span> <span class="s2">&quot;none&quot;</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">profiling_verbosity</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ProfilingVerbosity</span><span class="o">.</span><span class="n">NONE</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">profiling_verbosity</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">ProfilingVerbosity</span><span class="o">.</span><span class="n">LAYER_NAMES_ONLY</span>
<span class="c1"># set timing cache</span>
<span class="n">cache</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="n">timing_cache</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="c1"># use given cache</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">timing_cache</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">ITimingCache</span><span class="p">):</span>
<span class="n">cache</span> <span class="o">=</span> <span class="n">timing_cache</span>
<span class="c1"># read cache from file</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">timing_cache</span><span class="p">,</span>
<span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">))</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">timing_cache</span><span class="p">):</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">timing_cache</span><span class="p">,</span> <span class="s2">&quot;rb&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">cache</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">create_timing_cache</span><span class="p">(</span><span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">())</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;Invalid timing cache, using freshly created one&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">cache</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">cache</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">create_timing_cache</span><span class="p">(</span><span class="sa">b</span><span class="s2">&quot;&quot;</span><span class="p">)</span>
<span class="c1"># When the user does not provide an existing cache, one is always created internally,</span>
<span class="c1"># so the cache should never be None here</span>
<span class="k">assert</span> <span class="n">cache</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">cache</span><span class="p">,</span> <span class="n">trt</span><span class="o">.</span><span class="n">ITimingCache</span><span class="p">)</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_timing_cache</span><span class="p">(</span><span class="n">cache</span><span class="p">,</span> <span class="n">ignore_mismatch</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="c1"># set weight sparsity</span>
<span class="n">weight_sparsity</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;weight_sparsity&quot;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<span class="k">if</span> <span class="n">weight_sparsity</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">SPARSE_WEIGHTS</span><span class="p">)</span>
<span class="c1"># TODO: remove this constraint after trt 10.6 is integrated</span>
<span class="k">if</span> <span class="n">trt_gte</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">6</span><span class="p">):</span>
<span class="c1"># set monitor memory</span>
<span class="n">monitor_memory</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;monitor_memory&quot;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<span class="k">if</span> <span class="n">monitor_memory</span><span class="p">:</span>
<span class="n">config</span><span class="o">.</span><span class="n">set_flag</span><span class="p">(</span><span class="n">trt</span><span class="o">.</span><span class="n">BuilderFlag</span><span class="o">.</span><span class="n">MONITOR_MEMORY</span><span class="p">)</span>
<span class="k">return</span> <span class="n">BuilderConfig</span><span class="p">()</span><span class="o">.</span><span class="n">_init</span><span class="p">(</span><span class="n">config</span><span class="p">,</span>
<span class="n">precision</span><span class="o">=</span><span class="n">precision</span><span class="p">,</span>
<span class="n">tensor_parallel</span><span class="o">=</span><span class="n">tensor_parallel</span><span class="p">,</span>
<span class="n">use_refit</span><span class="o">=</span><span class="n">use_refit</span><span class="p">,</span>
<span class="n">int8</span><span class="o">=</span><span class="n">int8</span><span class="p">,</span>
<span class="n">force_num_profiles</span><span class="o">=</span><span class="n">force_num_profiles</span><span class="p">,</span>
<span class="n">strongly_typed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">,</span>
<span class="n">use_strip_plan</span><span class="o">=</span><span class="n">use_strip_plan</span><span class="p">,</span>
<span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_add_optimization_profile</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">network</span><span class="p">:</span> <span class="n">Network</span><span class="p">,</span>
<span class="n">builder_config</span><span class="p">:</span> <span class="n">BuilderConfig</span><span class="p">):</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">builder_config</span><span class="p">,</span> <span class="n">BuilderConfig</span><span class="p">)</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">network</span><span class="p">,</span> <span class="n">Network</span><span class="p">)</span>
<span class="n">input_tensors</span> <span class="o">=</span> <span class="n">network</span><span class="o">.</span><span class="n">_inputs</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_tensors</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;There are no inputs in the network!&quot;</span><span class="p">)</span>
<span class="k">return</span>
<span class="n">num_profiles</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">input_tensors</span><span class="o">.</span><span class="n">values</span><span class="p">())[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">profiles</span><span class="p">)</span>
<span class="n">force_num_profiles</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">builder_config</span><span class="p">,</span> <span class="s2">&quot;force_num_profiles&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_profiles</span><span class="p">):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Adding optimization profile </span><span class="si">{</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s1">/</span><span class="si">{</span><span class="n">num_profiles</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">profile</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_builder</span><span class="o">.</span><span class="n">create_optimization_profile</span><span class="p">()</span>
<span class="k">for</span> <span class="n">input_name</span> <span class="ow">in</span> <span class="n">input_tensors</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_tensors</span><span class="p">[</span><span class="n">input_name</span><span class="p">]</span><span class="o">.</span><span class="n">profiles</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">shape_profile</span> <span class="o">=</span> <span class="n">input_tensors</span><span class="p">[</span><span class="n">input_name</span><span class="p">]</span><span class="o">.</span><span class="n">profiles</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
<span class="n">min_shape</span> <span class="o">=</span> <span class="p">[</span><span class="o">*</span><span class="n">shape_profile</span><span class="o">.</span><span class="n">min</span><span class="p">]</span>
<span class="n">opt_shape</span> <span class="o">=</span> <span class="p">[</span><span class="o">*</span><span class="n">shape_profile</span><span class="o">.</span><span class="n">opt</span><span class="p">]</span>
<span class="n">max_shape</span> <span class="o">=</span> <span class="p">[</span><span class="o">*</span><span class="n">shape_profile</span><span class="o">.</span><span class="n">max</span><span class="p">]</span>
<span class="n">profile</span><span class="o">.</span><span class="n">set_shape</span><span class="p">(</span><span class="n">input_name</span><span class="p">,</span> <span class="n">min_shape</span><span class="p">,</span> <span class="n">opt_shape</span><span class="p">,</span> <span class="n">max_shape</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">input_name</span><span class="si">}</span><span class="s1">, min: </span><span class="si">{</span><span class="n">min_shape</span><span class="si">}</span><span class="s1">, opt: </span><span class="si">{</span><span class="n">opt_shape</span><span class="si">}</span><span class="s1">, max: </span><span class="si">{</span><span class="n">max_shape</span><span class="si">}</span><span class="s1">, dimension names: </span><span class="si">{</span><span class="n">shape_profile</span><span class="o">.</span><span class="n">dimension_names</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">)</span>
<span class="n">ret</span> <span class="o">=</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">trt_builder_config</span><span class="o">.</span><span class="n">add_optimization_profile</span><span class="p">(</span>
<span class="n">profile</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Added optimization profile: #</span><span class="si">{</span><span class="n">ret</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">force_num_profiles</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="p">(</span>
<span class="n">i</span> <span class="o">+</span> <span class="mi">1</span>
<span class="p">)</span> <span class="o">==</span> <span class="n">force_num_profiles</span> <span class="ow">and</span> <span class="n">force_num_profiles</span> <span class="o">&lt;</span> <span class="n">num_profiles</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Only adding </span><span class="si">{</span><span class="n">force_num_profiles</span><span class="si">}</span><span class="s2"> profiles instead of </span><span class="si">{</span><span class="n">num_profiles</span><span class="si">}</span><span class="s2">.&quot;</span>
<span class="p">)</span>
<span class="k">break</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">_validate_named_dimensions</span><span class="p">(</span>
<span class="n">network</span><span class="p">,</span> <span class="n">builder_config</span>
<span class="p">),</span> <span class="s2">&quot;Validation of the tensor dimension ranges failed, please check the dimension ranges, find the offensive tensor and dimension name in above the error log&quot;</span>
<span class="k">def</span><span class="w"> </span><span class="nf">_validate_named_dimensions</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">network</span><span class="p">:</span> <span class="n">Network</span><span class="p">,</span>
<span class="n">builder_config</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> For each profile, validate that the named dimensions of different input tensors in this profile all have same range.</span>
<span class="sd"> TRT will validate the same condition, validate it earlier to make sure the modeling in TensorRT LLM are correct and</span>
<span class="sd"> makes the error msg more user friendly.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">valid</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">for</span> <span class="n">profile_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span>
<span class="n">builder_config</span><span class="o">.</span><span class="n">trt_builder_config</span><span class="o">.</span><span class="n">num_optimization_profiles</span><span class="p">):</span>
<span class="n">dimension_to_range</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">for</span> <span class="n">input_name</span><span class="p">,</span> <span class="n">input_tensor</span> <span class="ow">in</span> <span class="n">network</span><span class="o">.</span><span class="n">_inputs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="c1"># it&#39;s legal that a Tensor does not have dim_range?</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_tensor</span><span class="o">.</span><span class="n">profiles</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">profile</span> <span class="o">=</span> <span class="n">input_tensor</span><span class="o">.</span><span class="n">profiles</span><span class="p">[</span><span class="n">profile_idx</span><span class="p">]</span>
<span class="k">for</span> <span class="n">dim_idx</span><span class="p">,</span> <span class="n">dim_name</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">profile</span><span class="o">.</span><span class="n">dimension_names</span><span class="p">):</span>
<span class="k">if</span> <span class="n">dim_name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">dimension_to_range</span><span class="p">:</span>
<span class="n">dimension_to_range</span><span class="p">[</span><span class="n">dim_name</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
<span class="nb">min</span><span class="p">,</span> <span class="n">opt</span><span class="p">,</span> <span class="nb">max</span> <span class="o">=</span> <span class="n">profile</span><span class="o">.</span><span class="n">min</span><span class="p">[</span><span class="n">dim_idx</span><span class="p">],</span> <span class="n">profile</span><span class="o">.</span><span class="n">opt</span><span class="p">[</span>
<span class="n">dim_idx</span><span class="p">],</span> <span class="n">profile</span><span class="o">.</span><span class="n">max</span><span class="p">[</span><span class="n">dim_idx</span><span class="p">]</span>
<span class="n">dimension_to_range</span><span class="p">[</span><span class="n">dim_name</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
<span class="p">(</span><span class="n">input_name</span><span class="p">,</span> <span class="p">(</span><span class="nb">min</span><span class="p">,</span> <span class="n">opt</span><span class="p">,</span> <span class="nb">max</span><span class="p">)))</span>
<span class="k">for</span> <span class="n">dim</span><span class="p">,</span> <span class="n">ranges</span> <span class="ow">in</span> <span class="n">dimension_to_range</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">unique_ranges</span> <span class="o">=</span> <span class="nb">set</span><span class="p">([</span><span class="n">r</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">ranges</span><span class="p">])</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Validating dimension:</span><span class="si">{</span><span class="n">dim</span><span class="si">}</span><span class="s2">, ranges for this dim are:</span><span class="si">{</span><span class="n">unique_ranges</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">unique_ranges</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Found illegal dimension setting for profile </span><span class="si">{</span><span class="n">profile_idx</span><span class="si">}</span><span class="s2">, dimension name is: </span><span class="si">{</span><span class="n">dim</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="s2">&quot;Offensive tensors which have this dimension are:</span><span class="se">\n</span><span class="s2">&quot;</span> <span class="o">+</span>
<span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">r</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">dim</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">r</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">&quot;</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">ranges</span><span class="p">]))</span>
<span class="n">valid</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">return</span> <span class="n">valid</span>
<span class="nd">@_is_building</span>
<span class="k">def</span><span class="w"> </span><span class="nf">refit_engine</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">network</span><span class="p">:</span> <span class="n">Network</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">trt</span><span class="o">.</span><span class="n">IHostMemory</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> @brief: Refit one TensorRT engine using weights from the network,</span>
<span class="sd"> user should guarantee that the engine is built with REFIT flag, and the network has the same structure with the engine.</span>
<span class="sd"> @param engine_buffer: A serialized TensorRT engine.</span>
<span class="sd"> @param network: Network object.</span>
<span class="sd"> @return: A serialized TRT engine if refit successfully, None otherwise</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">network</span><span class="p">,</span> <span class="n">Network</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">&#39;Refit TRT engine&#39;</span><span class="p">)</span>
<span class="n">runtime</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Runtime</span><span class="p">(</span><span class="n">logger</span><span class="o">.</span><span class="n">trt_logger</span><span class="p">)</span>
<span class="n">engine</span> <span class="o">=</span> <span class="n">runtime</span><span class="o">.</span><span class="n">deserialize_cuda_engine</span><span class="p">(</span><span class="n">engine_buffer</span><span class="p">)</span>
<span class="n">tik</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="c1"># Refit engine</span>
<span class="n">refitter</span> <span class="o">=</span> <span class="n">trt</span><span class="o">.</span><span class="n">Refitter</span><span class="p">(</span><span class="n">engine</span><span class="p">,</span> <span class="n">logger</span><span class="o">.</span><span class="n">trt_logger</span><span class="p">)</span>
<span class="k">if</span> <span class="n">network</span><span class="o">.</span><span class="n">named_parameters</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">network</span><span class="o">.</span><span class="n">named_parameters</span><span class="p">:</span>
<span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">_get_weights</span><span class="p">(</span>
<span class="p">)</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">refitter</span><span class="o">.</span><span class="n">set_named_weights</span><span class="p">(</span>
<span class="n">name</span><span class="p">,</span> <span class="n">param</span><span class="o">.</span><span class="n">_get_weights</span><span class="p">()):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Failed to refit weight: </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="kc">None</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
<span class="s1">&#39;Please set named parameters before building multiple engines.&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="kc">None</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">refitter</span><span class="o">.</span><span class="n">refit_cuda_engine</span><span class="p">():</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="s1">&#39;Failed to refit engine.&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="kc">None</span>
<span class="n">tok</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="s1">&#39;%H:%M:%S&#39;</span><span class="p">,</span> <span class="n">time</span><span class="o">.</span><span class="n">gmtime</span><span class="p">(</span><span class="n">tok</span> <span class="o">-</span> <span class="n">tik</span><span class="p">))</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Total time of refitting </span><span class="si">{</span><span class="n">engine</span><span class="o">.</span><span class="n">name</span><span class="si">}</span><span class="s1">: </span><span class="si">{</span><span class="n">t</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">serialized_engine</span> <span class="o">=</span> <span class="n">engine</span><span class="o">.</span><span class="n">serialize</span><span class="p">()</span>
<span class="k">return</span> <span class="n">serialized_engine</span>
<span class="nd">@_is_building</span>
<span class="k">def</span><span class="w"> </span><span class="nf">build_engine</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
<span class="n">network</span><span class="p">:</span> <span class="n">Network</span><span class="p">,</span>
<span class="n">builder_config</span><span class="p">:</span> <span class="n">BuilderConfig</span><span class="p">,</span>
<span class="n">managed_weights</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">trt</span><span class="o">.</span><span class="n">IHostMemory</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;</span>
<span class="sd"> @brief: Build one TensorRT engine from the network.</span>
<span class="sd"> @param network: Network object.</span>
<span class="sd"> @param builder_config: BuilderConfig object.</span>
<span class="sd"> @return: A serialized TRT engine.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">network</span><span class="p">,</span> <span class="n">Network</span><span class="p">)</span>
<span class="n">builder_config</span><span class="o">.</span><span class="n">plugin_config</span> <span class="o">=</span> <span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span>
<span class="k">if</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">trt_builder_config</span><span class="o">.</span><span class="n">num_optimization_profiles</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_add_optimization_profile</span><span class="p">(</span><span class="n">network</span><span class="p">,</span> <span class="n">builder_config</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Total optimization profiles added: </span><span class="si">{</span><span class="n">builder_config</span><span class="o">.</span><span class="n">trt_builder_config</span><span class="o">.</span><span class="n">num_optimization_profiles</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="n">engine</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">tik</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="c1"># Rename weights</span>
<span class="k">if</span> <span class="n">network</span><span class="o">.</span><span class="n">named_parameters</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">managed_parameters</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">network</span><span class="o">.</span><span class="n">named_parameters</span><span class="p">:</span>
<span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">is_managed</span><span class="p">(</span><span class="n">network</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">managed_weights</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;managed_weights should be provided when enabled&quot;</span>
<span class="n">managed_parameters</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">param</span><span class="p">)</span>
<span class="n">param</span><span class="o">.</span><span class="n">set_name</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">network</span><span class="p">)</span>
<span class="k">continue</span>
<span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">_get_weights</span><span class="p">(</span><span class="n">network</span><span class="p">)</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">param</span><span class="o">.</span><span class="n">is_buffer</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Parameter </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">param</span><span class="o">.</span><span class="n">raw_value</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">param</span><span class="o">.</span><span class="n">raw_value</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2"> was created&quot;</span>
<span class="s2">&quot; but unused in forward method, so not materialized to TRT network&quot;</span>
<span class="p">)</span>
<span class="k">continue</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">param</span><span class="o">.</span><span class="n">set_name</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">network</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Failed to set weight: </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="c1"># This mark_weights_refittable has no side effect when refit_individual is not enabled.</span>
<span class="n">network</span><span class="o">.</span><span class="n">trt_network</span><span class="o">.</span><span class="n">mark_weights_refittable</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
<span class="n">network</span><span class="o">.</span><span class="n">_fill_weights</span><span class="p">()</span>
<span class="n">tok</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="s1">&#39;%H:%M:%S&#39;</span><span class="p">,</span> <span class="n">time</span><span class="o">.</span><span class="n">gmtime</span><span class="p">(</span><span class="n">tok</span> <span class="o">-</span> <span class="n">tik</span><span class="p">))</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;Total time to initialize the weights in network </span><span class="si">{</span><span class="n">network</span><span class="o">.</span><span class="n">trt_network</span><span class="o">.</span><span class="n">name</span><span class="si">}</span><span class="s1">: </span><span class="si">{</span><span class="n">t</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">)</span>
<span class="c1"># Build engine</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Build TensorRT engine </span><span class="si">{</span><span class="n">network</span><span class="o">.</span><span class="n">trt_network</span><span class="o">.</span><span class="n">name</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="n">tik</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">engine</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trt_builder</span><span class="o">.</span><span class="n">build_serialized_network</span><span class="p">(</span>
<span class="n">network</span><span class="o">.</span><span class="n">trt_network</span><span class="p">,</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">trt_builder_config</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">engine</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">&#39;Engine building failed, please check the error log.&#39;</span>
<span class="n">tok</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="s1">&#39;%H:%M:%S&#39;</span><span class="p">,</span> <span class="n">time</span><span class="o">.</span><span class="n">gmtime</span><span class="p">(</span><span class="n">tok</span> <span class="o">-</span> <span class="n">tik</span><span class="p">))</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Total time of building </span><span class="si">{</span><span class="n">network</span><span class="o">.</span><span class="n">trt_network</span><span class="o">.</span><span class="n">name</span><span class="si">}</span><span class="s1">: </span><span class="si">{</span><span class="n">t</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">managed_weights</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">network</span><span class="o">.</span><span class="n">named_parameters</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">managed_parameters</span><span class="p">:</span>
<span class="n">name</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">name</span>
<span class="n">value</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span> <span class="o">=</span> <span class="n">param</span><span class="o">.</span><span class="n">_value</span>
<span class="k">if</span> <span class="n">value</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Failed to get weight: </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">continue</span>
<span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">need_transpose</span><span class="p">:</span>
<span class="c1"># MOE has ndim=3 and uses plugin, no need to transpose</span>
<span class="n">value</span> <span class="o">=</span> <span class="n">value</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="c1"># WAR for bug 4641821</span>
<span class="n">managed_weights</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span>
<span class="k">return</span> <span class="n">engine</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">save_timing_cache</span><span class="p">(</span><span class="n">builder_config</span><span class="p">:</span> <span class="n">BuilderConfig</span><span class="p">,</span> <span class="n">out_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;Serialize the timing cache of the given builder config to the file specified by out_path.</span>
<span class="sd"> Return True if the cache is successfully serialized, False otherwise.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">cache</span> <span class="o">=</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">trt_builder_config</span><span class="o">.</span><span class="n">get_timing_cache</span><span class="p">()</span>
<span class="k">if</span> <span class="n">cache</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s1">&#39;No timing cache found in the given builder config, skip saving.&#39;</span>
<span class="p">)</span>
<span class="k">return</span> <span class="kc">False</span>
<span class="k">with</span> <span class="n">cache</span><span class="o">.</span><span class="n">serialize</span><span class="p">()</span> <span class="k">as</span> <span class="n">buffer</span><span class="p">:</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">out_path</span><span class="p">,</span> <span class="s2">&quot;wb&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">buffer</span><span class="p">)</span>
<span class="n">f</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
<span class="n">os</span><span class="o">.</span><span class="n">fsync</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Timing cache serialized to </span><span class="si">{</span><span class="n">out_path</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">return</span> <span class="kc">True</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">save_config</span><span class="p">(</span><span class="n">builder_config</span><span class="p">:</span> <span class="n">BuilderConfig</span><span class="p">,</span> <span class="n">config_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">builder_config</span><span class="o">.</span><span class="n">to_dict</span><span class="p">()</span>
<span class="n">to_json_file</span><span class="p">(</span><span class="n">config</span><span class="p">,</span> <span class="n">config_path</span><span class="p">)</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Config saved to </span><span class="si">{</span><span class="n">config_path</span><span class="si">}</span><span class="s1">.&#39;</span><span class="p">)</span>
<div class="viewcode-block" id="BuildConfig">
<a class="viewcode-back" href="../../llm-api/reference.html#tensorrt_llm.llmapi.BuildConfig">[docs]</a>
<span class="k">class</span><span class="w"> </span><span class="nc">BuildConfig</span><span class="p">(</span><span class="n">BaseModel</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Configuration class for TensorRT LLM engine building parameters.</span>
<span class="sd"> This class contains all the configuration parameters needed to build a TensorRT LLM engine,</span>
<span class="sd"> including sequence length limits, batch sizes, optimization settings, and various features.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">max_input_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1024</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Maximum length of input sequences.&quot;</span><span class="p">)</span>
<span class="n">max_seq_len</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span>
<span class="s2">&quot;The maximum possible sequence length for a single request, including both input and generated &quot;</span>
<span class="s2">&quot;output tokens.&quot;</span><span class="p">)</span>
<span class="n">opt_batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">&quot;Optimal batch size for engine optimization.&quot;</span><span class="p">)</span>
<span class="n">max_batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="mi">2048</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">&quot;Maximum batch size the engine can handle.&quot;</span><span class="p">)</span>
<span class="n">max_beam_width</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">&quot;Maximum beam width for beam search decoding.&quot;</span><span class="p">)</span>
<span class="n">max_num_tokens</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="mi">8192</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Maximum number of batched input tokens after padding is &quot;</span>
<span class="s2">&quot;removed in each batch.&quot;</span><span class="p">)</span>
<span class="n">opt_num_tokens</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span>
<span class="s2">&quot;Optimal number of batched input tokens for engine optimization.&quot;</span><span class="p">)</span>
<span class="n">max_prompt_embedding_table_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Maximum size of prompt embedding table for prompt tuning.&quot;</span><span class="p">)</span>
<span class="n">kv_cache_type</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">KVCacheType</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span>
<span class="s2">&quot;Type of KV cache to use (CONTINUOUS or PAGED). If None, defaults to PAGED.&quot;</span>
<span class="p">)</span>
<span class="n">gather_context_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Whether to gather logits during context phase.&quot;</span><span class="p">)</span>
<span class="n">gather_generation_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Whether to gather logits during generation phase.&quot;</span><span class="p">)</span>
<span class="n">strongly_typed</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Whether to use strongly_typed.&quot;</span><span class="p">)</span>
<span class="n">force_num_profiles</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span>
<span class="s2">&quot;Force a specific number of optimization profiles. If None, auto-determined.&quot;</span>
<span class="p">)</span>
<span class="n">profiling_verbosity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="s1">&#39;layer_names_only&#39;</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span>
<span class="s2">&quot;Verbosity level for TensorRT profiling (&#39;layer_names_only&#39;, &#39;detailed&#39;, &#39;none&#39;).&quot;</span>
<span class="p">)</span>
<span class="n">enable_debug_output</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Whether to enable debug output during building.&quot;</span><span class="p">)</span>
<span class="n">max_draft_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Maximum length of draft tokens for speculative decoding.&quot;</span><span class="p">)</span>
<span class="n">speculative_decoding_mode</span><span class="p">:</span> <span class="n">SpeculativeDecodingMode</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">NONE</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Mode for speculative decoding (NONE, MEDUSA, EAGLE, etc.).&quot;</span>
<span class="p">)</span>
<span class="n">use_refit</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Whether to enable engine refitting capabilities.&quot;</span><span class="p">)</span>
<span class="n">input_timing_cache</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span>
<span class="s2">&quot;Path to input timing cache file. If None, no input cache used.&quot;</span><span class="p">)</span>
<span class="n">output_timing_cache</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="s1">&#39;model.cache&#39;</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">&quot;Path to output timing cache file.&quot;</span><span class="p">)</span>
<span class="n">lora_config</span><span class="p">:</span> <span class="n">LoraConfig</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default_factory</span><span class="o">=</span><span class="n">LoraConfig</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Configuration for LoRA (Low-Rank Adaptation) fine-tuning.&quot;</span><span class="p">)</span>
<span class="n">weight_sparsity</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Whether to enable weight sparsity optimization.&quot;</span><span class="p">)</span>
<span class="n">weight_streaming</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Whether to enable weight streaming for large models.&quot;</span><span class="p">)</span>
<span class="n">plugin_config</span><span class="p">:</span> <span class="n">PluginConfig</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default_factory</span><span class="o">=</span><span class="n">PluginConfig</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Configuration for TensorRT LLM plugins.&quot;</span><span class="p">)</span>
<span class="n">use_strip_plan</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Whether to use stripped plan for engine building.&quot;</span><span class="p">)</span>
<span class="n">max_encoder_input_len</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="mi">1024</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Maximum encoder input length for encoder-decoder models.&quot;</span><span class="p">)</span>
<span class="n">dry_run</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span>
<span class="s2">&quot;Whether to perform a dry run without actually building the engine.&quot;</span><span class="p">)</span>
<span class="n">visualize_network</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span>
<span class="s2">&quot;Path to save network visualization. If None, no visualization generated.&quot;</span>
<span class="p">)</span>
<span class="n">monitor_memory</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span><span class="s2">&quot;Whether to monitor memory usage during building.&quot;</span><span class="p">)</span>
<span class="n">use_mrope</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">description</span><span class="o">=</span>
<span class="s2">&quot;Whether to use Multi-RoPE (Rotary Position Embedding) optimization.&quot;</span><span class="p">)</span>
<span class="c1"># Since kv_cache_type, paged_kv_cache, and paged_state overlap (the latter two will be deprecated in the future),</span>
<span class="c1"># we need to reconcile them based on the model architecture.</span>
<div class="viewcode-block" id="BuildConfig.update_kv_cache_type">
<a class="viewcode-back" href="../../llm-api/reference.html#tensorrt_llm.llmapi.BuildConfig.update_kv_cache_type">[docs]</a>
<span class="k">def</span><span class="w"> </span><span class="nf">update_kv_cache_type</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_architecture</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
<span class="n">paged_kv_cache_attr</span> <span class="o">=</span> <span class="s1">&#39;paged_state&#39;</span> <span class="k">if</span> <span class="n">model_architecture</span> <span class="ow">in</span> <span class="p">[</span>
<span class="s1">&#39;MambaForCausalLM&#39;</span><span class="p">,</span> <span class="s1">&#39;RecurrentGemmaForCausalLM&#39;</span>
<span class="p">]</span> <span class="k">else</span> <span class="s1">&#39;paged_kv_cache&#39;</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">paged_kv_cache_val</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">,</span> <span class="n">paged_kv_cache_attr</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="n">paged_kv_cache_val</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">assert</span> <span class="p">(</span><span class="n">paged_kv_cache_val</span> <span class="o">==</span> <span class="kc">True</span>
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">==</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">PAGED</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span>
<span class="n">paged_kv_cache_val</span> <span class="o">==</span> <span class="kc">False</span>
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">!=</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">PAGED</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">,</span> <span class="n">paged_kv_cache_attr</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">==</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">PAGED</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="n">paged_kv_cache_val</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">=</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">PAGED</span> <span class="k">if</span> <span class="n">paged_kv_cache_val</span> <span class="k">else</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">CONTINUOUS</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">=</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">PAGED</span>
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">,</span> <span class="n">paged_kv_cache_attr</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">==</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">PAGED</span><span class="p">)</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="nb">getattr</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">,</span> <span class="n">paged_kv_cache_attr</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="k">def</span><span class="w"> </span><span class="nf">override_attri</span><span class="p">(</span><span class="n">attr_name</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
<span class="n">val</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">,</span> <span class="n">attr_name</span><span class="p">)</span>
<span class="k">if</span> <span class="n">val</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">val</span> <span class="o">!=</span> <span class="n">value</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Overriding </span><span class="si">{</span><span class="n">attr_name</span><span class="si">}</span><span class="s1"> to </span><span class="si">{</span><span class="n">value</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">,</span> <span class="n">attr_name</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
<span class="c1"># Set the other paged-cache attribute. For RecurrentGemma, we only support paged_state and paged_kv_cache having</span>
<span class="c1"># the same value. All other models should consume only one of the two attributes and set the other to False.</span>
<span class="n">is_recurrent_gemma</span> <span class="o">=</span> <span class="n">model_architecture</span> <span class="o">==</span> <span class="s1">&#39;RecurrentGemmaForCausalLM&#39;</span>
<span class="k">if</span> <span class="n">paged_kv_cache_attr</span> <span class="o">==</span> <span class="s1">&#39;paged_state&#39;</span><span class="p">:</span>
<span class="n">override_attri</span><span class="p">(</span>
<span class="s1">&#39;paged_kv_cache&#39;</span><span class="p">,</span>
<span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">plugin_config</span><span class="p">,</span> <span class="n">paged_kv_cache_attr</span><span class="p">)</span>
<span class="k">if</span> <span class="n">is_recurrent_gemma</span> <span class="k">else</span> <span class="kc">False</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">override_attri</span><span class="p">(</span><span class="s1">&#39;paged_state&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span></div>
<div class="viewcode-block" id="BuildConfig.from_json_file">
<a class="viewcode-back" href="../../llm-api/reference.html#tensorrt_llm.llmapi.BuildConfig.from_json_file">[docs]</a>
<span class="nd">@classmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">from_json_file</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">config_file</span><span class="p">):</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">config_file</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
<span class="k">return</span> <span class="n">BuildConfig</span><span class="p">(</span><span class="o">**</span><span class="n">config</span><span class="p">)</span></div>
</div>
<span class="k">class</span><span class="w"> </span><span class="nc">EngineConfig</span><span class="p">:</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pretrained_config</span><span class="p">:</span> <span class="s1">&#39;PretrainedConfig&#39;</span><span class="p">,</span>
<span class="n">build_config</span><span class="p">:</span> <span class="s1">&#39;BuildConfig&#39;</span><span class="p">,</span> <span class="n">version</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">pretrained_config</span> <span class="o">=</span> <span class="n">pretrained_config</span>
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span> <span class="o">=</span> <span class="n">build_config</span>
<span class="bp">self</span><span class="o">.</span><span class="n">version</span> <span class="o">=</span> <span class="n">version</span>
<span class="nd">@classmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">from_json_file</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">config_file</span><span class="p">):</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">config_file</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">from_json_str</span><span class="p">(</span><span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">())</span>
<span class="nd">@classmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">from_json_str</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">config_str</span><span class="p">):</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">loads</span><span class="p">(</span><span class="n">config_str</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="n">PretrainedConfig</span><span class="o">.</span><span class="n">from_dict</span><span class="p">(</span><span class="n">config</span><span class="p">[</span><span class="s1">&#39;pretrained_config&#39;</span><span class="p">]),</span>
<span class="n">BuildConfig</span><span class="p">(</span><span class="o">**</span><span class="n">config</span><span class="p">[</span><span class="s1">&#39;build_config&#39;</span><span class="p">]),</span> <span class="n">config</span><span class="p">[</span><span class="s1">&#39;version&#39;</span><span class="p">])</span>
<span class="k">def</span><span class="w"> </span><span class="nf">to_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="n">build_config</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">model_dump</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">&quot;json&quot;</span><span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;dry_run&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span> <span class="c1"># Not an Engine Characteristic</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;visualize_network&#39;</span><span class="p">,</span>
<span class="kc">None</span><span class="p">)</span> <span class="c1"># Not an Engine Characteristic</span>
<span class="k">return</span> <span class="p">{</span>
<span class="s1">&#39;version&#39;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">version</span><span class="p">,</span>
<span class="s1">&#39;pretrained_config&#39;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">pretrained_config</span><span class="o">.</span><span class="n">to_dict</span><span class="p">(),</span>
<span class="s1">&#39;build_config&#39;</span><span class="p">:</span> <span class="n">build_config</span><span class="p">,</span>
<span class="p">}</span>
<span class="k">class</span><span class="w"> </span><span class="nc">Engine</span><span class="p">:</span>
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">config</span><span class="p">:</span> <span class="n">EngineConfig</span><span class="p">,</span>
<span class="n">engine</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">IHostMemory</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span>
<span class="n">managed_weights</span><span class="p">:</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">=</span> <span class="p">{},</span>
<span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">config</span> <span class="o">=</span> <span class="n">config</span>
<span class="bp">self</span><span class="o">.</span><span class="n">engine</span> <span class="o">=</span> <span class="n">engine</span>
<span class="bp">self</span><span class="o">.</span><span class="n">managed_weights</span> <span class="o">=</span> <span class="n">managed_weights</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">managed_weights</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">managed_weights</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">managed_weights</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">value</span><span class="o">.</span><span class="n">flags</span><span class="p">[</span><span class="s1">&#39;C_CONTIGUOUS&#39;</span><span class="p">]:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">managed_weights</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">save</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">engine_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">lora_config</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">lora_config</span>
<span class="n">lora_dirs</span> <span class="o">=</span> <span class="n">lora_config</span><span class="o">.</span><span class="n">lora_dir</span>
<span class="n">root_lora_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">,</span> <span class="s1">&#39;lora&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">lora_dirs</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">root_lora_dir</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">lora_dir</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">lora_dirs</span><span class="p">):</span>
<span class="k">if</span> <span class="n">lora_config</span><span class="o">.</span><span class="n">lora_ckpt_source</span> <span class="o">==</span> <span class="s1">&#39;hf&#39;</span><span class="p">:</span>
<span class="n">target_lora_dir</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">root_lora_dir</span><span class="si">}</span><span class="s2">/</span><span class="si">{</span><span class="n">index</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">target_lora_dir</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">shutil</span><span class="o">.</span><span class="n">copy2</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">lora_dir</span><span class="p">,</span> <span class="s1">&#39;adapter_config.json&#39;</span><span class="p">),</span>
<span class="n">target_lora_dir</span><span class="p">)</span>
<span class="n">weight_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">lora_dir</span><span class="p">,</span> <span class="s1">&#39;adapter_model.bin&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">weight_file</span><span class="p">):</span>
<span class="n">shutil</span><span class="o">.</span><span class="n">copy2</span><span class="p">(</span><span class="n">weight_file</span><span class="p">,</span> <span class="n">target_lora_dir</span><span class="p">)</span>
<span class="n">weight_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">lora_dir</span><span class="p">,</span>
<span class="s1">&#39;adapter_model.safetensors&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">weight_file</span><span class="p">):</span>
<span class="n">shutil</span><span class="o">.</span><span class="n">copy2</span><span class="p">(</span><span class="n">weight_file</span><span class="p">,</span> <span class="n">target_lora_dir</span><span class="p">)</span>
<span class="n">lora_config</span><span class="o">.</span><span class="n">lora_dir</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;lora/</span><span class="si">{</span><span class="n">index</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="k">elif</span> <span class="n">lora_config</span><span class="o">.</span><span class="n">lora_ckpt_source</span> <span class="o">==</span> <span class="s1">&#39;nemo&#39;</span><span class="p">:</span>
<span class="n">target_lora_file</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">root_lora_dir</span><span class="si">}</span><span class="s2">/</span><span class="si">{</span><span class="n">index</span><span class="si">}</span><span class="s2">.nemo&quot;</span>
<span class="n">shutil</span><span class="o">.</span><span class="n">copyfile</span><span class="p">(</span><span class="n">lora_dir</span><span class="p">,</span> <span class="n">target_lora_file</span><span class="p">)</span>
<span class="n">lora_config</span><span class="o">.</span><span class="n">lora_dir</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;lora/</span><span class="si">{</span><span class="n">index</span><span class="si">}</span><span class="s2">.nemo&quot;</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">root_lora_dir</span><span class="p">)</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isdir</span><span class="p">(</span><span class="n">root_lora_dir</span><span class="p">):</span>
<span class="n">shutil</span><span class="o">.</span><span class="n">rmtree</span><span class="p">(</span><span class="n">root_lora_dir</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">pretrained_config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">config_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">to_dict</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">pretrained_config</span><span class="o">.</span><span class="n">quant_algo</span> <span class="o">==</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">MIXED_PRECISION</span><span class="p">:</span>
<span class="n">quant_dict</span> <span class="o">=</span> <span class="p">{</span>
<span class="s1">&#39;version&#39;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">version</span><span class="p">,</span>
<span class="p">}</span>
<span class="n">quant_dict</span><span class="o">.</span><span class="n">update</span><span class="p">(</span>
<span class="n">config_dict</span><span class="p">[</span><span class="s1">&#39;pretrained_config&#39;</span><span class="p">][</span><span class="s1">&#39;quantization&#39;</span><span class="p">])</span>
<span class="n">config_dict</span><span class="p">[</span><span class="s1">&#39;pretrained_config&#39;</span><span class="p">][</span><span class="s1">&#39;quantization&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span>
<span class="s1">&#39;quantized_layers&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">,</span> <span class="s1">&#39;quant_cfg.json&#39;</span><span class="p">),</span>
<span class="s2">&quot;w&quot;</span><span class="p">,</span>
<span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">json</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="n">quant_dict</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">indent</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="bp">cls</span><span class="o">=</span><span class="n">ConfigEncoder</span><span class="p">)</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">,</span> <span class="s1">&#39;config.json&#39;</span><span class="p">),</span>
<span class="s2">&quot;w&quot;</span><span class="p">,</span>
<span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">json</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="n">config_dict</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">indent</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="bp">cls</span><span class="o">=</span><span class="n">ConfigEncoder</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">serialize_engine</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="p">,</span>
<span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span>
<span class="n">engine_dir</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;rank</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">pretrained_config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span><span class="si">}</span><span class="s1">.engine&#39;</span><span class="p">))</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">managed_weights</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">managed_weights</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">fn</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span>
<span class="n">engine_dir</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;rank</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">pretrained_config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span><span class="si">}</span><span class="s1">_managed_weights.safetensors&#39;</span>
<span class="p">)</span>
<span class="n">serialize_managed_weights</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">managed_weights</span><span class="p">,</span> <span class="n">fn</span><span class="p">)</span>
<span class="nd">@classmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">from_dir</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">engine_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">,</span> <span class="sa">f</span><span class="s1">&#39;rank</span><span class="si">{</span><span class="n">rank</span><span class="si">}</span><span class="s1">.engine&#39;</span><span class="p">),</span> <span class="s1">&#39;rb&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">engine_buffer</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">()</span>
<span class="n">mw_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">,</span>
<span class="sa">f</span><span class="s1">&#39;rank</span><span class="si">{</span><span class="n">rank</span><span class="si">}</span><span class="s1">_managed_weights.safetensors&#39;</span><span class="p">)</span>
<span class="n">managed_weights</span> <span class="o">=</span> <span class="n">deserialize_managed_weights</span><span class="p">(</span>
<span class="n">mw_path</span><span class="p">)</span> <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">mw_path</span><span class="p">)</span> <span class="k">else</span> <span class="kc">None</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">EngineConfig</span><span class="o">.</span><span class="n">from_json_file</span><span class="p">(</span>
<span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">,</span> <span class="s1">&#39;config.json&#39;</span><span class="p">))</span>
<span class="n">config</span><span class="o">.</span><span class="n">pretrained_config</span><span class="o">.</span><span class="n">set_rank</span><span class="p">(</span><span class="n">rank</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="n">config</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">,</span> <span class="n">managed_weights</span><span class="p">)</span>
<span class="nd">@classmethod</span>
<span class="k">def</span><span class="w"> </span><span class="nf">from_buffer</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span>
<span class="n">engine_buffer</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">trt</span><span class="o">.</span><span class="n">IHostMemory</span><span class="p">,</span> <span class="nb">bytes</span><span class="p">],</span>
<span class="n">json_config_str</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">rank</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">EngineConfig</span><span class="o">.</span><span class="n">from_json_str</span><span class="p">(</span><span class="n">json_config_str</span><span class="p">)</span>
<span class="n">config</span><span class="o">.</span><span class="n">pretrained_config</span><span class="o">.</span><span class="n">set_rank</span><span class="p">(</span><span class="n">rank</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="n">config</span><span class="p">,</span> <span class="n">engine_buffer</span><span class="p">)</span>
<span class="k">def</span><span class="w"> </span><span class="nf">get_engine_version</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="nb">str</span><span class="p">]:</span>
<span class="n">engine_dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="n">engine_dir</span><span class="p">)</span>
<span class="n">config_path</span> <span class="o">=</span> <span class="n">engine_dir</span> <span class="o">/</span> <span class="s2">&quot;config.json&quot;</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">config_path</span><span class="p">,</span> <span class="s1">&#39;r&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
<span class="k">if</span> <span class="s1">&#39;version&#39;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">config</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">None</span>
<span class="k">return</span> <span class="n">config</span><span class="p">[</span><span class="s1">&#39;version&#39;</span><span class="p">]</span>
<span class="k">def</span><span class="w"> </span><span class="nf">optimize_model_with_config</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">PretrainedModel</span><span class="p">,</span>
<span class="n">build_config</span><span class="p">:</span> <span class="n">BuildConfig</span><span class="p">):</span>
<span class="n">gemm_swiglu_plugin</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_swiglu_plugin</span>
<span class="n">low_latency_gemm_swiglu_plugin</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">low_latency_gemm_swiglu_plugin</span>
<span class="k">if</span> <span class="n">gemm_swiglu_plugin</span> <span class="ow">or</span> <span class="n">low_latency_gemm_swiglu_plugin</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fused_mlp</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;GemmSwiGLU plugin requires --use_fused_mlp flag&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">gemm_swiglu_plugin</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span>
<span class="s2">&quot;fp8&quot;</span>
<span class="p">]</span> <span class="ow">and</span> <span class="n">low_latency_gemm_swiglu_plugin</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;fp8&quot;</span><span class="p">]:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;GemmSwiGLU plugin currently has limited support: fp8 only, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;got: </span><span class="si">{</span><span class="n">gemm_swiglu_plugin</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="sa">f</span><span class="s2">&quot;got: </span><span class="si">{</span><span class="n">low_latency_gemm_swiglu_plugin</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lora_plugin</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">model</span><span class="o">.</span><span class="n">use_lora</span><span class="p">(</span><span class="n">build_config</span><span class="o">.</span><span class="n">lora_config</span><span class="p">)</span>
<span class="n">is_enc_dec</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;EncoderModel&quot;</span><span class="p">,</span> <span class="s2">&quot;DecoderModel&quot;</span><span class="p">]</span>
<span class="c1"># FusedMLP does not support RecurrentGemma FP8 currently.</span>
<span class="n">is_recurrent_gemma</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="ow">in</span> <span class="p">[</span>
<span class="s2">&quot;RecurrentGemmaForCausalLM&quot;</span>
<span class="p">]</span>
<span class="n">is_fp8</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quantization</span><span class="o">.</span><span class="n">quant_algo</span> <span class="o">==</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">optimize_model</span><span class="p">(</span>
<span class="n">model</span><span class="p">,</span>
<span class="n">share_embedding_table</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">use_ootb_moe</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">moe_plugin</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">use_fused_mlp</span><span class="o">=</span><span class="p">(</span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fused_mlp</span>
<span class="ow">and</span> <span class="ow">not</span> <span class="n">is_enc_dec</span>
<span class="ow">and</span> <span class="ow">not</span> <span class="p">(</span><span class="n">is_recurrent_gemma</span> <span class="ow">and</span> <span class="n">is_fp8</span><span class="p">)),</span>
<span class="n">gemm_swiglu_plugin_dtype</span><span class="o">=</span><span class="n">gemm_swiglu_plugin</span><span class="p">,</span>
<span class="n">low_latency_gemm_swiglu_plugin_dtype</span><span class="o">=</span><span class="n">low_latency_gemm_swiglu_plugin</span><span class="p">,</span>
<span class="n">use_fused_rg_lru</span><span class="o">=</span><span class="n">is_recurrent_gemma</span><span class="p">,</span>
<span class="n">use_unfused_qkv_gemm</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">use_prompt_tuning</span><span class="o">=</span><span class="p">(</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_prompt_embedding_table_size</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">),</span>
<span class="n">use_lora</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lora_plugin</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">max_lora_rank</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">lora_config</span><span class="o">.</span><span class="n">max_lora_rank</span><span class="p">,</span>
<span class="n">use_fp8_context_fmha</span><span class="o">=</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quantization</span><span class="o">.</span><span class="n">quant_algo</span> <span class="ow">in</span> <span class="p">[</span>
<span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8</span><span class="p">,</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">W4A8_AWQ</span><span class="p">,</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">NVFP4</span>
<span class="p">]</span> <span class="ow">and</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span><span class="p">),</span>
<span class="n">fuse_fp4_quant</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">fuse_fp4_quant</span><span class="p">,</span>
<span class="n">use_optimize_cross_qkv</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">use_dora</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">dora_plugin</span><span class="p">)</span>
<span class="k">if</span> <span class="n">is_enc_dec</span><span class="p">:</span>
<span class="n">model</span><span class="o">.</span><span class="n">precompute_relative_attention_bias</span><span class="p">(</span><span class="n">build_config</span><span class="p">)</span>
<span class="k">return</span> <span class="n">model</span>
def _init_max_seq_len(model_config, build_config):
    """
    Fill in and validate the sequence-length limits stored on ``build_config``.

    If max_seq_len is not specified, set it to max_position_embeddings * rotary_factor
    (2048 is used as the base when the model config has no max_position_embeddings).
    Additional checks ensure max_seq_len, max_input_len, and max_num_tokens have valid values.

    Args:
        model_config: model configuration. ``architecture``, ``max_position_embeddings``,
            ``position_embedding_type`` and the optional ``rotary_scaling`` mapping are read.
        build_config: build configuration, updated in place (``max_seq_len``,
            ``max_input_len``, ``max_num_tokens`` and ``opt_num_tokens`` may be rewritten).

    Raises:
        AssertionError: if an EncoderModel ends up with different max_input_len and
            max_seq_len, if max_input_len is missing while padding removal and fMHA are
            not both enabled, or if max_input_len exceeds max_seq_len in that same case.
    """
    # Extract rotary scaling which will be used for checks and default value of max_seq_len.
    # The configured factor is ignored for the "su", "longrope" and "llama3" types.
    rotary_factor = 1
    rotary_scaling = getattr(model_config, "rotary_scaling", None)
    if rotary_scaling is not None:
        rotary_type = rotary_scaling.get('type',
                                         rotary_scaling.get('rope_type'))
        if rotary_type not in ("su", "longrope", "llama3"):
            rotary_factor = rotary_scaling.get('factor', 1.0)

    if model_config.architecture == "EncoderModel":
        if build_config.max_seq_len is None:
            build_config.max_seq_len = build_config.max_input_len
            logger.info(
                'max_seq_len is not specified for EncoderModel, using --max_input_len.'
            )
        assert build_config.max_input_len == build_config.max_seq_len, f"EncoderModel should have same --max_input_len ({build_config.max_input_len}) and --max_seq_len ({build_config.max_seq_len})."

    if build_config.max_seq_len is None:
        # Step 1: Find the upper bound of max_seq_len
        deduced_max_seq_len = 2048
        if model_config.max_position_embeddings is not None:
            deduced_max_seq_len = model_config.max_position_embeddings

        # Step 2: Scale max_seq_len with rotary scaling
        if rotary_factor != 1:
            deduced_max_seq_len = math.ceil(deduced_max_seq_len * rotary_factor)
            logger.warning(
                f'max_seq_len is scaled to {deduced_max_seq_len} by rotary scaling {rotary_factor}'
            )

        # Step 3: Assign the new max_seq_len
        build_config.max_seq_len = int(deduced_max_seq_len)
        logger.info(
            f'max_seq_len is not specified, using deduced value {deduced_max_seq_len}'
        )
    else:
        # A user-supplied max_seq_len is only checked against the positional limit
        # (and only warned about) when StreamingLLM is off, the model declares
        # max_position_embeddings and does not use relative position embeddings.
        positional_limit_applies = (
            not build_config.plugin_config.streamingllm
            and model_config.max_position_embeddings is not None
            and model_config.position_embedding_type !=
            PositionEmbeddingType.relative)
        if positional_limit_applies and build_config.max_seq_len > model_config.max_position_embeddings * rotary_factor:
            logger.warning(
                f'max_seq_len {build_config.max_seq_len} is larger than max_position_embeddings {model_config.max_position_embeddings} * rotary scaling {rotary_factor}, '
                'the model accuracy might be affected')

    # max_input_len may legitimately be unset (see the final check below), so guard the
    # comparison instead of failing with a bare TypeError on ``None > int``.
    # NOTE(review): check_max_num_tokens still receives max_input_len=None in that case;
    # confirm it tolerates it.
    if build_config.max_input_len is not None and build_config.max_input_len > build_config.max_seq_len:
        logger.warning(
            f'max_input_len {build_config.max_input_len} is larger than max_seq_len {build_config.max_seq_len}, clipping it to max_seq_len'
        )
        build_config.max_input_len = build_config.max_seq_len

    # Check and may modify max_num_tokens and opt_num_tokens (need to happen after max_seq_len is deduced)
    max_num_tokens, opt_num_tokens = check_max_num_tokens(
        max_num_tokens=build_config.max_num_tokens,
        opt_num_tokens=build_config.opt_num_tokens,
        max_batch_size=build_config.max_batch_size,
        max_input_len=build_config.max_input_len,
        max_seq_len=build_config.max_seq_len,
        max_beam_width=build_config.max_beam_width,
        remove_input_padding=build_config.plugin_config.remove_input_padding,
        enable_context_fmha=build_config.plugin_config.context_fmha,
        tokens_per_block=build_config.plugin_config.tokens_per_block,
        multiple_profiles=build_config.plugin_config.multiple_profiles,
    )
    build_config.max_num_tokens, build_config.opt_num_tokens = max_num_tokens, opt_num_tokens

    if build_config.plugin_config.remove_input_padding and build_config.plugin_config.context_fmha:
        if build_config.max_input_len:
            logger.warning(
                'padding removal and fMHA are both enabled, max_input_len is not required and will be ignored'
            )
    else:
        assert build_config.max_input_len is not None, "padding removal and fMHA aren't both enabled, max_input_len is required"
        if build_config.max_seq_len:
            assert build_config.max_input_len <= build_config.max_seq_len, 'max_input_len should not be larger than max_seq_len'
def serialize_managed_weights(managed_weights: dict[str, np.ndarray],
                              path: str | Path,
                              metadata=None) -> None:
    """Write ``managed_weights`` to ``path`` as a single binary file.

    File layout, in order:
      1. 8 bytes: length of the encoded JSON header, little-endian.
      2. The JSON header: one entry per weight holding its ``dtype`` tag, ``shape`` and
         ``data_offsets`` (``[begin, end]`` byte offsets relative to the start of the
         data section, ``end`` exclusive), plus a ``__metadata__`` entry when
         ``metadata`` is given.
      3. The raw bytes of every weight, in dictionary iteration order.

    Args:
        managed_weights: mapping from weight name to array.
        path: destination file; opened in ``"wb"`` mode, so existing content is replaced.
        metadata: optional JSON-serializable value stored under ``"__metadata__"``.

    Raises:
        RuntimeError: if a weight's dtype has no tag in the table below.
    """
    # (numpy dtype, header tag) pairs; matched with ``==`` in this order.
    dtype_tags = (
        (np.float32, "F32"),
        (np.float16, "F16"),
        (np_bfloat16, "BF16"),
        (np_float8, "F8_E4M3"),
        (np.int64, "I64"),
        (np.int32, "I32"),
        (np.int8, "I8"),
    )

    header = {}
    if metadata is not None:
        header["__metadata__"] = metadata

    begin = 0
    for name, value in managed_weights.items():
        for np_dtype, tag in dtype_tags:
            if value.dtype == np_dtype:
                dtype = tag
                break
        else:
            raise RuntimeError(f"Unsupported dtype: {value.dtype}")

        size = value.size * value.itemsize
        header[name] = {
            "dtype": dtype,
            "shape": value.shape,
            "data_offsets": [begin, begin + size],
        }
        begin += size

    # Measure the header after encoding so the stored length is always the number of
    # bytes actually written, independent of how the JSON text is encoded.
    header_bytes = json.dumps(header).encode()

    with open(path, "wb") as f:
        logger.info(
            f"Serializing {len(managed_weights)} managed weights to {path}...")
        f.write(len(header_bytes).to_bytes(8, byteorder="little"))
        f.write(header_bytes)
        for name, value in managed_weights.items():
            logger.debug(f"Serializing managed weight: {name}")
            # File writes need a C-contiguous buffer; a strided view (e.g. a transposed
            # weight) is copied into row-major order first, which is also the order the
            # reader assumes when it reshapes. Contiguous arrays are passed through
            # without a copy.
            f.write(np.ascontiguousarray(value).data)
<span class="k">def</span><span class="w"> </span><span class="nf">deserialize_managed_weights</span><span class="p">(</span><span class="n">path</span><span class="p">:</span> <span class="nb">str</span> <span class="o">|</span> <span class="n">Path</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]:</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">&quot;rb&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="n">header_json_len</span> <span class="o">=</span> <span class="nb">int</span><span class="o">.</span><span class="n">from_bytes</span><span class="p">(</span><span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="mi">8</span><span class="p">),</span> <span class="n">byteorder</span><span class="o">=</span><span class="s2">&quot;little&quot;</span><span class="p">)</span>
<span class="n">header_json</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">header_json_len</span><span class="p">)</span><span class="o">.</span><span class="n">decode</span><span class="p">()</span>
<span class="n">header</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">loads</span><span class="p">(</span><span class="n">header_json</span><span class="p">)</span>
<span class="n">managed_weights</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">info</span> <span class="ow">in</span> <span class="n">header</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">info</span><span class="p">[</span><span class="s2">&quot;dtype&quot;</span><span class="p">]</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">info</span><span class="p">[</span><span class="s2">&quot;shape&quot;</span><span class="p">]</span>
<span class="n">data_offsets</span> <span class="o">=</span> <span class="n">info</span><span class="p">[</span><span class="s2">&quot;data_offsets&quot;</span><span class="p">]</span>
<span class="k">if</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">&quot;F32&quot;</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float32</span>
<span class="k">elif</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">&quot;F16&quot;</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float16</span>
<span class="k">elif</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">&quot;BF16&quot;</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">np_bfloat16</span>
<span class="k">elif</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">&quot;F8_E4M3&quot;</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">np_float8</span>
<span class="k">elif</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">&quot;I64&quot;</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">int64</span>
<span class="k">elif</span> <span class="n">dtype</span> <span class="o">==</span> <span class="s2">&quot;I32&quot;</span><span class="p">:</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">int32</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Unsupported dtype: </span><span class="si">{</span><span class="n">dtype</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">f</span><span class="o">.</span><span class="n">seek</span><span class="p">(</span><span class="n">data_offsets</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">header_json_len</span> <span class="o">+</span> <span class="mi">8</span><span class="p">)</span>
<span class="n">buf</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">data_offsets</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">data_offsets</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">value</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">frombuffer</span><span class="p">(</span><span class="n">buf</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">shape</span><span class="p">)</span>
<span class="n">managed_weights</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span>
<span class="k">return</span> <span class="n">managed_weights</span>
<span class="k">def</span><span class="w"> </span><span class="nf">build</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">PretrainedModel</span><span class="p">,</span> <span class="n">build_config</span><span class="p">:</span> <span class="n">BuildConfig</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Engine</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&#39;&#39;&#39;Build engine from given model and optimization options specified in the build_config</span>
<span class="sd"> WARNING: this function may change the given model object state in some optimization passes</span>
<span class="sd"> to avoid cloning a model since normally the LLM models consumes large memory.</span>
<span class="sd"> Create a new fresh model object if you need to build with different options.</span>
<span class="sd"> &#39;&#39;&#39;</span>
<span class="n">tic</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="c1"># avoid changing the input config</span>
<span class="n">build_config</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">model_copy</span><span class="p">(</span><span class="n">deep</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">update_kv_cache_type</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span><span class="p">)</span>
<span class="n">_init_max_seq_len</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="p">,</span> <span class="n">build_config</span><span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">streamingllm</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;Paged Context FMHA is disabled because StreamingLLM is not supported when enabling paged KV context FMHA.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">reduce_fusion</span> <span class="ow">and</span> <span class="p">(</span>
<span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span>
<span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">!=</span> <span class="s2">&quot;LlamaForCausalLM&quot;</span>
<span class="ow">and</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">!=</span> <span class="s2">&quot;Gemma2ForCausalLM&quot;</span>
<span class="ow">and</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">!=</span> <span class="s2">&quot;MedusaForCausalLM&quot;</span><span class="p">)):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s1">&#39;Overriding reduce_fusion to False&#39;</span><span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">reduce_fusion</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">user_buffer</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">reduce_fusion</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s1">&#39;Overriding user_buffer to False&#39;</span><span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">user_buffer</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">norm_quant_fusion</span> <span class="ow">and</span> <span class="p">(</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">reduce_fusion</span>
<span class="ow">or</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">!=</span> <span class="s2">&quot;LlamaForCausalLM&quot;</span>
<span class="ow">or</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quantization</span><span class="o">.</span><span class="n">quant_algo</span> <span class="o">!=</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">NVFP4</span><span class="p">):</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s1">&#39;Overriding norm_quant_fusion to False&#39;</span><span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">norm_quant_fusion</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quantization</span><span class="o">.</span><span class="n">quant_algo</span> <span class="o">==</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8</span> <span class="ow">or</span> \
<span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quantization</span><span class="o">.</span><span class="n">kv_cache_quant_algo</span> <span class="o">==</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">strongly_typed</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="p">,</span> <span class="s1">&#39;max_draft_len&#39;</span><span class="p">):</span>
<span class="c1"># If model.config has &#39;max_draft_len&#39; but build_config not specified,</span>
<span class="c1"># use the value of model.config.max_draft_len to set the value of build_config.max_draft_len</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">max_draft_len</span>
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="p">,</span> <span class="s1">&#39;redrafter_num_beams&#39;</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">hasattr</span><span class="p">(</span>
<span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="p">,</span> <span class="s1">&#39;redrafter_draft_len_per_beam&#39;</span><span class="p">):</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">redrafter_num_beams</span> <span class="o">*</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">redrafter_draft_len_per_beam</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">!=</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">EXPLICIT_DRAFT_TOKENS</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s1">&#39;speculative_decoding_mode is not EXPLICIT_DRAFT_TOKENS for ReDrafter model. Overwriting speculative_decoding_mode&#39;</span>
<span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">=</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">EXPLICIT_DRAFT_TOKENS</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">!=</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">NONE</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;Increasing max_seq_len (</span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="si">}</span><span class="s1">) &#39;</span>
<span class="sa">f</span><span class="s1">&#39;by max_draft_len (</span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span><span class="si">}</span><span class="s1">) &#39;</span>
<span class="s1">&#39;to account for speculative decoding implementation specifics. &#39;</span>
<span class="s1">&#39;Maximum number of generated tokens remains the same. &#39;</span>
<span class="sa">f</span><span class="s1">&#39;New max_seq_len is set to </span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span> <span class="o">+=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">==</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">EAGLE</span><span class="p">:</span>
<span class="k">assert</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="p">,</span> <span class="s1">&#39;num_eagle_layers&#39;</span><span class="p">)</span>
<span class="n">num_eagle_layers</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">num_eagle_layers</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;Increasing max_seq_len (</span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="si">}</span><span class="s1">) &#39;</span>
<span class="sa">f</span><span class="s1">&#39;by num_eagle_layers (</span><span class="si">{</span><span class="n">num_eagle_layers</span><span class="si">}</span><span class="s1">) &#39;</span>
<span class="s1">&#39;to account for EAGLE implementation specifics. &#39;</span>
<span class="s1">&#39;Maximum number of generated tokens remains the same. &#39;</span>
<span class="sa">f</span><span class="s1">&#39;New max_seq_len is set to </span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">num_eagle_layers</span><span class="si">}</span><span class="s1">&#39;</span>
<span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span> <span class="o">+=</span> <span class="n">num_eagle_layers</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">!=</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">NONE</span><span class="p">:</span>
<span class="n">num_tokens</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_batch_size</span> <span class="o">*</span> <span class="p">(</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">+</span>
<span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_num_tokens</span> <span class="o">&lt;</span> <span class="n">num_tokens</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s1">&#39;max_num_tokens (</span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_num_tokens</span><span class="si">}</span><span class="s1">) is smaller than &#39;</span>
<span class="s1">&#39;max_batch_size * (max_draft_len + 1) = &#39;</span>
<span class="sa">f</span><span class="s1">&#39;(</span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="si">}</span><span class="s1"> * (</span><span class="si">{</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span><span class="si">}</span><span class="s1"> + 1)). &#39;</span>
<span class="sa">f</span><span class="s1">&#39;New max_num_tokens is set to </span><span class="si">{</span><span class="n">num_tokens</span><span class="si">}</span><span class="s1">.&#39;</span><span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_num_tokens</span> <span class="o">=</span> <span class="n">num_tokens</span>
<span class="c1"># Logics to control paged_context_fmha and fp8_context_fmha</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">context_fmha</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;Context FMHA is disabled, FP8 Context FMHA and Paged Context FMHA are disabled.&quot;</span>
<span class="p">)</span>
<span class="k">elif</span> <span class="ow">not</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quantization</span><span class="o">.</span><span class="n">quant_algo</span> <span class="ow">in</span> <span class="p">[</span>
<span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8</span><span class="p">,</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">W4A8_AWQ</span><span class="p">,</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">NVFP4</span>
<span class="p">]:</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;FP8 Context FMHA is disabled because it must be used together with the fp8 quantization workflow.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="ow">and</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_fp8_kv_cache</span><span class="p">(</span>
<span class="p">):</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;FP8 Paged Context FMHA is disabled because FP8 context FMHA is disabled.&quot;</span>
<span class="p">)</span>
<span class="k">elif</span> <span class="n">get_sm_version</span><span class="p">()</span> <span class="o">&lt;</span> <span class="mi">89</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;FP8 context FMHA is disabled because it is only supported on Ada and Hopper Arch.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="ow">and</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_fp8_kv_cache</span><span class="p">(</span>
<span class="p">):</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;FP8 Paged Context FMHA is disabled because FP8 context FMHA is disabled.&quot;</span>
<span class="p">)</span>
<span class="k">elif</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_fp8_kv_cache</span><span class="p">(</span>
<span class="p">)</span> <span class="ow">and</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;FP8 Paged Context FMHA is disabled because it must be used together with fp8 KV Cache.&quot;</span>
<span class="p">)</span>
<span class="k">elif</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_fp8_kv_cache</span><span class="p">(</span>
<span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_fp8_context_fmha</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;FP8 Context FMHA is enabled to support FP8 Paged Context FMHA.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="ow">and</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_int8_kv_cache</span><span class="p">(</span>
<span class="p">):</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="o">=</span> <span class="kc">False</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;Paged Context FMHA is disabled because it doesn&#39;t work with int8 kv cache currently.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">get_sm_version</span><span class="p">()</span> <span class="o">&gt;=</span> <span class="mi">100</span> <span class="ow">and</span> <span class="n">get_sm_version</span><span class="p">()</span> <span class="o">&lt;</span> <span class="mi">120</span><span class="p">:</span>
<span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">is_int8_weight_only</span><span class="p">(</span>
<span class="p">)</span> <span class="ow">or</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">is_int4_weight_only</span><span class="p">(</span>
<span class="p">)</span> <span class="ow">or</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_int8_kv_cache</span><span class="p">():</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;INT8/INT4 quantization is not supported on SM&gt;=100.&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_act_and_weight_quant</span><span class="p">():</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s2">&quot;SmoothQuant is not supported on SM&gt;=100.&quot;</span><span class="p">)</span>
<span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_per_channel_scaling</span><span class="p">(</span>
<span class="p">)</span> <span class="ow">or</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_per_token_dynamic_scaling</span><span class="p">():</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;Per-channel or per-token scaling is not supported on SM&gt;=100.&quot;</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">optimize_model_with_config</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">build_config</span><span class="p">)</span>
<span class="n">builder</span> <span class="o">=</span> <span class="n">Builder</span><span class="p">()</span>
<span class="n">builder_config</span> <span class="o">=</span> <span class="n">builder</span><span class="o">.</span><span class="n">create_builder_config</span><span class="p">(</span>
<span class="n">precision</span><span class="o">=</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span>
<span class="n">use_refit</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">use_refit</span><span class="p">,</span>
<span class="n">timing_cache</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">input_timing_cache</span><span class="p">,</span>
<span class="n">int8</span><span class="o">=</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_act_or_weight_quant</span><span class="p">()</span>
<span class="ow">and</span> <span class="ow">not</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_per_group_scaling</span><span class="p">())</span>
<span class="ow">or</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_int8_kv_cache</span><span class="p">(),</span>
<span class="n">strongly_typed</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">strongly_typed</span><span class="p">,</span>
<span class="n">force_num_profiles</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">force_num_profiles</span><span class="p">,</span>
<span class="n">profiling_verbosity</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">profiling_verbosity</span><span class="p">,</span>
<span class="n">quant_mode</span><span class="o">=</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="p">,</span>
<span class="n">use_strip_plan</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">use_strip_plan</span><span class="p">,</span>
<span class="n">weight_sparsity</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">weight_sparsity</span><span class="p">,</span>
<span class="n">weight_streaming</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">weight_streaming</span><span class="p">,</span>
<span class="n">monitor_memory</span><span class="o">=</span><span class="n">build_config</span><span class="o">.</span><span class="n">monitor_memory</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">network</span> <span class="o">=</span> <span class="n">builder</span><span class="o">.</span><span class="n">create_network</span><span class="p">()</span>
<span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span>
<span class="n">use_weight_only</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">is_weight_only</span><span class="p">()</span>
<span class="n">per_group</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_per_group_scaling</span><span class="p">()</span>
<span class="n">use_smooth_quant</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_act_and_weight_quant</span><span class="p">()</span>
<span class="n">use_qserve</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">is_qserve_w4a8</span><span class="p">()</span>
<span class="n">use_fp8_rowwise</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_fp8_rowwise</span><span class="p">()</span>
<span class="n">disable_weight_only_quant_plugin</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">disable_weight_only_quant_plugin</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span>
<span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="p">,</span> <span class="s1">&#39;disable_weight_only_quant_plugin&#39;</span><span class="p">)</span> <span class="k">else</span> <span class="kc">False</span>
<span class="n">use_fp8_rowwise</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_fp8_rowwise</span><span class="p">()</span>
<span class="n">use_fp4_gemm</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quant_mode</span><span class="o">.</span><span class="n">has_nvfp4</span><span class="p">()</span>
<span class="k">if</span> <span class="n">use_fp4_gemm</span> <span class="ow">and</span> <span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">_explicitly_disable_gemm_plugin</span> <span class="ow">is</span> <span class="kc">False</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="s1">&#39;NVFP4 quantization detected, by default enabling NVFP4 GEMM plugin. To use OOTB GEMM, please explicitly set gemm_plugin to &quot;disable&quot;&#39;</span>
<span class="p">)</span>
<span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">gemm_plugin</span> <span class="o">=</span> <span class="s2">&quot;nvfp4&quot;</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">manage_weights</span><span class="p">:</span>
<span class="k">if</span> <span class="n">use_weight_only</span> <span class="ow">and</span> <span class="n">disable_weight_only_quant_plugin</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;Manage weights of weight only quant works only with plugin currently.&quot;</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">use_weight_only</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">disable_weight_only_quant_plugin</span><span class="p">:</span>
<span class="k">if</span> <span class="n">per_group</span><span class="p">:</span>
<span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">weight_only_groupwise_quant_matmul_plugin</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">weight_only_quant_matmul_plugin</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span>
<span class="k">if</span> <span class="n">use_smooth_quant</span> <span class="ow">and</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">quantization</span><span class="o">.</span><span class="n">_use_plugin_sq</span> <span class="ow">and</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">smooth_quant_plugins</span><span class="p">:</span>
<span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">set_smooth_quant_plugins</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_qserve</span><span class="p">:</span>
<span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">set_qserve_plugins</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_fp8_rowwise</span><span class="p">:</span>
<span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">set_fp8_rowwise_quant_plugins</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
<span class="n">nccl_plugin</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span> <span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">world_size</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="k">else</span> <span class="kc">None</span>
<span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">set_nccl_plugin</span><span class="p">(</span><span class="n">nccl_plugin</span><span class="p">)</span>
<span class="k">with</span> <span class="n">net_guard</span><span class="p">(</span><span class="n">network</span><span class="p">):</span>
<span class="c1"># Prepare</span>
<span class="n">network</span><span class="o">.</span><span class="n">set_named_parameters</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">named_parameters</span><span class="p">())</span>
<span class="c1"># Forward</span>
<span class="n">prepare_input_args</span> <span class="o">=</span> <span class="p">{</span>
<span class="s2">&quot;max_batch_size&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">,</span>
<span class="s2">&quot;max_input_len&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_input_len</span><span class="p">,</span>
<span class="s2">&quot;max_seq_len&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="p">,</span>
<span class="s2">&quot;use_cache&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">kv_cache_type</span> <span class="o">!=</span> <span class="n">KVCacheType</span><span class="o">.</span><span class="n">DISABLED</span><span class="p">,</span>
<span class="s2">&quot;max_beam_width&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_beam_width</span><span class="p">,</span>
<span class="s2">&quot;max_num_tokens&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_num_tokens</span><span class="p">,</span>
<span class="s2">&quot;opt_num_tokens&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">opt_num_tokens</span><span class="p">,</span>
<span class="s2">&quot;prompt_embedding_table_size&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_prompt_embedding_table_size</span><span class="p">,</span>
<span class="s2">&quot;max_draft_len&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span><span class="p">,</span>
<span class="s2">&quot;speculative_decoding_draft_tokens_external&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">==</span>
<span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">DRAFT_TOKENS_EXTERNAL</span><span class="p">,</span>
<span class="s2">&quot;gather_context_logits&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">gather_context_logits</span><span class="p">,</span>
<span class="s2">&quot;lora_target_modules&quot;</span><span class="p">:</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">lora_config</span><span class="o">.</span><span class="n">lora_target_modules</span>
<span class="p">}</span>
<span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">==</span> <span class="s2">&quot;DecoderModel&quot;</span> <span class="ow">or</span> <span class="s2">&quot;mllama&quot;</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span><span class="o">.</span><span class="n">lower</span><span class="p">(</span>
<span class="p">):</span>
<span class="n">prepare_input_args</span><span class="p">[</span><span class="s2">&quot;max_seq_len&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span>
<span class="n">prepare_input_args</span><span class="p">[</span>
<span class="s2">&quot;max_decoder_input_len&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_input_len</span>
<span class="n">prepare_input_args</span><span class="p">[</span>
<span class="s2">&quot;max_encoder_input_len&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_encoder_input_len</span>
<span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">==</span> <span class="s2">&quot;WhisperEncoder&quot;</span><span class="p">:</span>
<span class="n">prepare_input_args</span> <span class="o">=</span> <span class="p">{</span>
<span class="s2">&quot;max_batch_size&quot;</span><span class="p">:</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">,</span>
<span class="p">}</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">==</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">EAGLE</span><span class="p">:</span>
<span class="n">prepare_input_args</span><span class="p">[</span>
<span class="s2">&quot;spec_decoding_is_generation_length_variable&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">assert</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_batch_size</span> <span class="o">&lt;=</span> <span class="mi">512</span><span class="p">,</span> <span class="s2">&quot;Max batch size &gt; 512 is not supported for EAGLE&quot;</span>
<span class="k">assert</span> <span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">&lt;=</span> <span class="mi">256</span><span class="p">,</span> <span class="s2">&quot;Max draft len &gt; 256 is not supported for EAGLE&quot;</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">==</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">LOOKAHEAD_DECODING</span><span class="p">:</span>
<span class="n">prepare_input_args</span><span class="p">[</span>
<span class="s2">&quot;spec_decoding_is_generation_length_variable&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
<span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">==</span> <span class="s2">&quot;Qwen2VLForConditionalGeneration&quot;</span> <span class="ow">or</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">==</span> <span class="s2">&quot;Qwen2VLModel&quot;</span><span class="p">:</span>
<span class="n">prepare_input_args</span><span class="p">[</span>
<span class="s1">&#39;mrope_rotary_cos_sin_size&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">max_position_embeddings</span> <span class="o">*</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">rotary_embedding_dim</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">==</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">EAGLE</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span><span class="p">:</span>
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
<span class="s2">&quot;Paged Context FMHA is required for EAGLE. Turning it on&quot;</span><span class="p">)</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">use_paged_context_fmha</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">prepare_inputs</span><span class="p">(</span><span class="o">**</span><span class="n">prepare_input_args</span><span class="p">)</span>
<span class="n">model</span><span class="p">(</span><span class="o">**</span><span class="n">inputs</span><span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">enable_debug_output</span><span class="p">:</span>
<span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">named_network_outputs</span><span class="p">():</span>
<span class="n">network</span><span class="o">.</span><span class="n">_mark_output</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">str_dtype_to_trt</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype</span><span class="p">))</span>
<span class="k">if</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">architecture</span> <span class="o">!=</span> <span class="s2">&quot;DecoderModel&quot;</span><span class="p">:</span>
<span class="n">optimize</span><span class="p">(</span><span class="n">network</span><span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">visualize_network</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">with</span> <span class="n">net_guard</span><span class="p">(</span><span class="n">network</span><span class="p">):</span>
<span class="n">network</span><span class="o">.</span><span class="n">to_onnx</span><span class="p">(</span><span class="n">build_config</span><span class="o">.</span><span class="n">visualize_network</span><span class="p">)</span>
<span class="c1"># Network -&gt; Engine</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Total time of constructing network from module object </span><span class="si">{</span><span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span><span class="o">-</span><span class="n">tic</span><span class="si">}</span><span class="s2"> seconds&quot;</span>
<span class="p">)</span>
<span class="n">managed_weights</span> <span class="o">=</span> <span class="p">{}</span> <span class="k">if</span> <span class="n">network</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">manage_weights</span> <span class="k">else</span> <span class="kc">None</span>
<span class="n">engine</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">dry_run</span> <span class="k">else</span> <span class="n">builder</span><span class="o">.</span><span class="n">build_engine</span><span class="p">(</span>
<span class="n">network</span><span class="p">,</span> <span class="n">builder_config</span><span class="p">,</span> <span class="n">managed_weights</span><span class="p">)</span>
<span class="n">engine_config</span> <span class="o">=</span> <span class="n">EngineConfig</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="p">,</span> <span class="n">build_config</span><span class="p">,</span> <span class="n">__version__</span><span class="p">)</span>
<span class="k">if</span> <span class="n">build_config</span><span class="o">.</span><span class="n">output_timing_cache</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">ok</span> <span class="o">=</span> <span class="n">builder</span><span class="o">.</span><span class="n">save_timing_cache</span><span class="p">(</span><span class="n">builder_config</span><span class="p">,</span>
<span class="n">build_config</span><span class="o">.</span><span class="n">output_timing_cache</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">ok</span><span class="p">,</span> <span class="s2">&quot;Failed to save timing cache.&quot;</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">psutil</span>
<span class="c1"># Get the current process</span>
<span class="n">current_process</span> <span class="o">=</span> <span class="n">psutil</span><span class="o">.</span><span class="n">Process</span><span class="p">()</span>
<span class="c1"># Get resource usage for the current process (self)</span>
<span class="n">rusage_s</span> <span class="o">=</span> <span class="n">current_process</span><span class="o">.</span><span class="n">memory_info</span><span class="p">()</span>
<span class="c1"># Get resource usage for all child processes</span>
<span class="n">children</span> <span class="o">=</span> <span class="n">current_process</span><span class="o">.</span><span class="n">children</span><span class="p">(</span><span class="n">recursive</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">rusage_c</span> <span class="o">=</span> <span class="p">[</span><span class="n">child</span><span class="o">.</span><span class="n">memory_info</span><span class="p">()</span> <span class="k">for</span> <span class="n">child</span> <span class="ow">in</span> <span class="n">children</span><span class="p">]</span>
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Build phase peak memory: </span><span class="si">{</span><span class="n">rusage_s</span><span class="o">.</span><span class="n">rss</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="mi">1024</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="mi">1024</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2"> MB, children: </span><span class="si">{</span><span class="nb">sum</span><span class="p">([</span><span class="n">ru</span><span class="o">.</span><span class="n">rss</span><span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="n">ru</span><span class="w"> </span><span class="ow">in</span><span class="w"> </span><span class="n">rusage_c</span><span class="p">])</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="mi">1024</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="mi">1024</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2"> MB&quot;</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">Engine</span><span class="p">(</span><span class="n">engine_config</span><span class="p">,</span> <span class="n">engine</span><span class="p">,</span> <span class="n">managed_weights</span><span class="p">)</span>
</pre></div>
</article>
<!-- Previous/next page navigation footer; its link area holds no entries on this page. -->
<footer class="prev-next-footer d-print-none">
  <div class="prev-next-area">
  </div>
</footer>
</div>
<!-- Secondary sidebar container; it has no content on this page. NOTE(review): left as a truly empty element (no inner whitespace) in case theme CSS keys off :empty; confirm against the theme stylesheet. -->
<div class="bd-sidebar-secondary"></div>
</div>
<!-- Content-area footer; it carries no items on this page. -->
<footer class="bd-footer-content">
</footer>
</main>
</div>
</div>
<!-- Theme scripts are placed late in <body> and marked `defer`: they download without blocking HTML parsing and run in document order once parsing has finished. -->
<script defer src="../../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
<script defer src="../../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
<!-- Site footer: brand logo, policy links, copyright notice, and build provenance. -->
<footer class="bd-footer">
  <div class="bd-footer__inner bd-page-width">
    <div class="footer-items__start">
      <div class="footer-item">
        <!-- Two renderings of the same logo (only-light / only-dark, presumably toggled by the active theme); both share the link, so they keep identical alt text. -->
        <a class="footer-brand logo" href="https://www.nvidia.com">
          <img class="logo__image only-light" src="../../_static/nvidia-logo-horiz-rgb-1c-blk-for-screen.svg" alt="NVIDIA">
          <img class="logo__image only-dark" src="../../_static/nvidia-logo-horiz-rgb-1c-wht-for-screen.svg" alt="NVIDIA">
        </a>
      </div>
      <div class="footer-item">
        <!-- The "|" separators are purely visual, so they are hidden from assistive technology while the rendered text stays the same. -->
        <div class="footer-links">
          <a class="external" href="https://www.nvidia.com/en-us/about-nvidia/privacy-policy/">Privacy Policy</a>
          <span aria-hidden="true">|</span>
          <a class="external" href="https://www.nvidia.com/en-us/about-nvidia/privacy-center/">Your Privacy Choices</a>
          <span aria-hidden="true">|</span>
          <a class="external" href="https://www.nvidia.com/en-us/about-nvidia/terms-of-service/">Terms of Service</a>
          <span aria-hidden="true">|</span>
          <a class="external" href="https://www.nvidia.com/en-us/about-nvidia/accessibility/">Accessibility</a>
          <span aria-hidden="true">|</span>
          <a class="external" href="https://www.nvidia.com/en-us/about-nvidia/company-policies/">Corporate Policies</a>
          <span aria-hidden="true">|</span>
          <a class="external" href="https://www.nvidia.com/en-us/product-security/">Product Security</a>
          <span aria-hidden="true">|</span>
          <a class="external" href="https://www.nvidia.com/en-us/contact/">Contact</a>
        </div>
      </div>
      <div class="footer-item">
        <!-- Brand spelled "NVIDIA" to match the logo alt text and links above; the trailing line break that only served as spacing is dropped. -->
        <p class="copyright">
          Copyright © 2025, NVIDIA.
        </p>
      </div>
      <div class="footer-item">
        <div class="extra_footer">
          <!-- The datetime attribute gives the visible date a machine-readable form. -->
          <p>Last updated on <time datetime="2026-01-04">January 04, 2026</time>.</p>
          <p>This page is generated by TensorRT-LLM commit <a href="https://github.com/NVIDIA/TensorRT-LLM/tree/a65b0d4">a65b0d4</a>.</p>
        </div>
      </div>
    </div>
  </div>
</footer>
</body>
</html>