mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
4027 lines
483 KiB
HTML
|
||
|
||
<!DOCTYPE html>
|
||
|
||
|
||
<html lang="en" data-content_root="../../../" >
|
||
|
||
<head>
|
||
<meta charset="utf-8" />
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||
<title>tensorrt_llm.llmapi.llm_args — TensorRT LLM</title>
|
||
|
||
|
||
|
||
<script data-cfasync="false">
|
||
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
|
||
document.documentElement.dataset.theme = localStorage.getItem("theme") || "";
|
||
</script>
|
||
<!--
|
||
this give us a css class that will be invisible only if js is disabled
|
||
-->
|
||
<noscript>
|
||
<style>
|
||
.pst-js-only { display: none !important; }
|
||
|
||
</style>
|
||
</noscript>
|
||
|
||
<!-- Loaded before other Sphinx assets -->
|
||
<link href="../../../_static/styles/theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
||
<link href="../../../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf" rel="stylesheet" />
|
||
|
||
<link rel="stylesheet" type="text/css" href="../../../_static/pygments.css?v=8f2a1f02" />
|
||
<link rel="stylesheet" type="text/css" href="../../../_static/styles/nvidia-sphinx-theme.css?v=933278ad" />
|
||
<link rel="stylesheet" type="text/css" href="../../../_static/copybutton.css?v=76b2166b" />
|
||
<link rel="stylesheet" type="text/css" href="../../../_static/autodoc_pydantic.css" />
|
||
<link rel="stylesheet" type="text/css" href="../../../_static/togglebutton.css?v=13237357" />
|
||
<link rel="stylesheet" type="text/css" href="../../../_static/custom.css?v=19d20f17" />
|
||
|
||
<!-- So that users can add custom icons -->
|
||
<script src="../../../_static/scripts/fontawesome.js?digest=8878045cc6db502f8baf"></script>
|
||
<!-- Pre-loaded scripts that we'll load fully later -->
|
||
<link rel="preload" as="script" href="../../../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf" />
|
||
<link rel="preload" as="script" href="../../../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf" />
|
||
|
||
|
||
|
||
<script src="../../../_static/documentation_options.js?v=5929fcd5"></script>
|
||
<script src="../../../_static/doctools.js?v=9a2dae69"></script>
|
||
<script src="../../../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||
<script src="../../../_static/clipboard.min.js?v=a7894cd8"></script>
|
||
<script src="../../../_static/copybutton.js?v=65e89d2a"></script>
|
||
<script>let toggleHintShow = 'Click to show';</script>
|
||
<script>let toggleHintHide = 'Click to hide';</script>
|
||
<script>let toggleOpenOnPrint = 'true';</script>
|
||
<script src="../../../_static/togglebutton.js?v=4a39c7ea"></script>
|
||
<script>var togglebuttonSelector = '.toggle, .admonition.dropdown';</script>
|
||
<script>DOCUMENTATION_OPTIONS.pagename = '_modules/tensorrt_llm/llmapi/llm_args';</script>
|
||
<script>
|
||
DOCUMENTATION_OPTIONS.theme_version = '0.16.1';
|
||
DOCUMENTATION_OPTIONS.theme_switcher_json_url = './_static/switcher.json';
|
||
DOCUMENTATION_OPTIONS.theme_switcher_version_match = '1.2.0rc4';
|
||
DOCUMENTATION_OPTIONS.show_version_warning_banner =
|
||
false;
|
||
</script>
|
||
|
||
<link rel="icon" href="../../../_static/favicon.png"/>
|
||
|
||
<link rel="index" title="Index" href="../../../genindex.html" />
|
||
<link rel="search" title="Search" href="../../../search.html" />
|
||
|
||
|
||
<!-- duplicate <meta name="viewport"> removed; the viewport is already declared next to <meta charset> at the top of <head> -->
|
||
<meta name="docsearch:language" content="en"/>
|
||
<meta name="docsearch:version" content="1.2.0rc4" />
|
||
|
||
|
||
</head>
|
||
|
||
|
||
|
||
<body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode="">
|
||
|
||
|
||
|
||
<div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div>
|
||
|
||
|
||
|
||
<div id="pst-scroll-pixel-helper"></div>
|
||
|
||
<button type="button" class="btn rounded-pill" id="pst-back-to-top">
|
||
<i class="fa-solid fa-arrow-up"></i>Back to top</button>
|
||
|
||
|
||
<dialog id="pst-search-dialog">
|
||
|
||
<form class="bd-search d-flex align-items-center"
|
||
action="../../../search.html"
|
||
method="get">
|
||
<i class="fa-solid fa-magnifying-glass"></i>
|
||
<input type="search"
|
||
class="form-control"
|
||
name="q"
|
||
placeholder="Search the docs ..."
|
||
aria-label="Search the docs ..."
|
||
autocomplete="off"
|
||
autocorrect="off"
|
||
autocapitalize="off"
|
||
spellcheck="false"/>
|
||
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span>
|
||
</form>
|
||
</dialog>
|
||
|
||
<div class="pst-async-banner-revealer d-none">
|
||
<aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside>
|
||
</div>
|
||
|
||
|
||
<header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none">
|
||
<div class="bd-header__inner bd-page-width">
|
||
<button class="pst-navbar-icon sidebar-toggle primary-toggle" aria-label="Site navigation">
|
||
<span class="fa-solid fa-bars"></span>
|
||
</button>
|
||
|
||
|
||
<div class="col-lg-3 navbar-header-items__start">
|
||
|
||
<div class="navbar-item">
|
||
|
||
|
||
|
||
|
||
|
||
<a class="navbar-brand logo" href="../../../index.html">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<img src="../../../_static/nvidia-logo-horiz-rgb-blk-for-screen.svg" class="logo__image only-light" alt="TensorRT LLM - Home"/>
|
||
<img src="../../../_static/nvidia-logo-horiz-rgb-wht-for-screen.svg" class="logo__image only-dark pst-js-only" alt="TensorRT LLM - Home"/>
|
||
|
||
|
||
<p class="title logo__title">TensorRT LLM</p>
|
||
|
||
</a></div>
|
||
|
||
</div>
|
||
|
||
<div class="col-lg-9 navbar-header-items">
|
||
|
||
<div class="me-auto navbar-header-items__center">
|
||
|
||
<div class="navbar-item">
|
||
|
||
|
||
<div class="version-switcher__container dropdown pst-js-only">
|
||
<button id="pst-version-switcher-button-2"
|
||
type="button"
|
||
class="version-switcher__button btn btn-sm dropdown-toggle"
|
||
data-bs-toggle="dropdown"
|
||
aria-haspopup="listbox"
|
||
aria-controls="pst-version-switcher-list-2"
|
||
aria-label="Version switcher list"
|
||
>
|
||
Choose version <!-- this text may get changed later by javascript -->
|
||
<span class="caret"></span>
|
||
</button>
|
||
<div id="pst-version-switcher-list-2"
|
||
class="version-switcher__menu dropdown-menu list-group-flush py-0"
|
||
role="listbox" aria-labelledby="pst-version-switcher-button-2">
|
||
<!-- dropdown will be populated by javascript on page load -->
|
||
</div>
|
||
</div></div>
|
||
|
||
</div>
|
||
|
||
|
||
<div class="navbar-header-items__end">
|
||
|
||
<div class="navbar-item navbar-persistent--container">
|
||
|
||
|
||
<button class="btn search-button-field search-button__button pst-js-only" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<i class="fa-solid fa-magnifying-glass"></i>
|
||
<span class="search-button__default-text">Search</span>
|
||
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
|
||
</button>
|
||
</div>
|
||
|
||
|
||
<div class="navbar-item">
|
||
|
||
<button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button pst-js-only" aria-label="Color mode" data-bs-title="Color mode" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light" title="Light"></i>
|
||
<i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark" title="Dark"></i>
|
||
<i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto" title="System Settings"></i>
|
||
</button></div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
|
||
<div class="navbar-persistent--mobile">
|
||
|
||
<button class="btn search-button-field search-button__button pst-js-only" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<i class="fa-solid fa-magnifying-glass"></i>
|
||
<span class="search-button__default-text">Search</span>
|
||
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span>
|
||
</button>
|
||
</div>
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</header>
|
||
|
||
|
||
<div class="bd-container">
|
||
<div class="bd-container__inner bd-page-width">
|
||
|
||
|
||
|
||
<dialog id="pst-primary-sidebar-modal"></dialog>
|
||
<div id="pst-primary-sidebar" class="bd-sidebar-primary bd-sidebar">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<a class="navbar-brand logo" href="../../../index.html">
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
<img src="../../../_static/nvidia-logo-horiz-rgb-blk-for-screen.svg" class="logo__image only-light" alt="TensorRT LLM - Home"/>
|
||
<img src="../../../_static/nvidia-logo-horiz-rgb-wht-for-screen.svg" class="logo__image only-dark pst-js-only" alt="TensorRT LLM - Home"/>
|
||
|
||
|
||
<p class="title logo__title">TensorRT LLM</p>
|
||
|
||
</a>
|
||
|
||
|
||
|
||
<div class="sidebar-header-items sidebar-primary__section">
|
||
|
||
|
||
<div class="sidebar-header-items__center">
|
||
|
||
|
||
|
||
<div class="navbar-item">
|
||
|
||
|
||
<div class="version-switcher__container dropdown pst-js-only">
|
||
<button id="pst-version-switcher-button-3"
|
||
type="button"
|
||
class="version-switcher__button btn btn-sm dropdown-toggle"
|
||
data-bs-toggle="dropdown"
|
||
aria-haspopup="listbox"
|
||
aria-controls="pst-version-switcher-list-3"
|
||
aria-label="Version switcher list"
|
||
>
|
||
Choose version <!-- this text may get changed later by javascript -->
|
||
<span class="caret"></span>
|
||
</button>
|
||
<div id="pst-version-switcher-list-3"
|
||
class="version-switcher__menu dropdown-menu list-group-flush py-0"
|
||
role="listbox" aria-labelledby="pst-version-switcher-button-3">
|
||
<!-- dropdown will be populated by javascript on page load -->
|
||
</div>
|
||
</div></div>
|
||
|
||
|
||
</div>
|
||
|
||
|
||
|
||
<div class="sidebar-header-items__end">
|
||
|
||
<div class="navbar-item">
|
||
|
||
<button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button pst-js-only" aria-label="Color mode" data-bs-title="Color mode" data-bs-placement="bottom" data-bs-toggle="tooltip">
|
||
<i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light" title="Light"></i>
|
||
<i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark" title="Dark"></i>
|
||
<i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto" title="System Settings"></i>
|
||
</button></div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
|
||
<div class="sidebar-primary-items__start sidebar-primary__section">
|
||
<div class="sidebar-primary-item">
|
||
|
||
|
||
|
||
<nav class="bd-docs-nav bd-links"
|
||
aria-label="Table of Contents">
|
||
<p class="bd-links__title" role="heading" aria-level="1">Table of Contents</p>
|
||
<div class="bd-toc-item navbar-nav"><p aria-level="2" class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../overview.html">Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../quick-start-guide.html">Quick Start Guide</a></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../installation/index.html">Installation</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../installation/containers.html">Pre-built release container images on NGC</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../installation/linux.html">Installing on Linux via <code class="docutils literal notranslate"><span class="pre">pip</span></code></a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../installation/build-from-source-linux.html">Building from Source Code on Linux</a></li>
|
||
</ul>
|
||
</details></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Deployment Guide</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../examples/llm_api_examples.html">LLM Examples</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_inference.html">Generate text</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_inference_async.html">Generate text asynchronously</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_inference_async_streaming.html">Generate text in streaming</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_inference_distributed.html">Distributed LLM Generation</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_guided_decoding.html">Generate text with guided decoding</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_logits_processor.html">Control generated text using logits processor</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_multilora.html">Generate text with multiple LoRA adapters</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_sparse_attention.html">Sparse Attention</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_speculative_decoding.html">Speculative Decoding</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_kv_cache_connector.html">KV Cache Connector</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_kv_cache_offloading.html">KV Cache Offloading</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_runtime.html">Runtime Configuration Examples</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_sampling.html">Sampling Techniques Showcase</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_mgmn_llm_distributed.html">Run LLM-API with pytorch backend on Slurm</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_mgmn_trtllm_bench.html">Run trtllm-bench with pytorch backend on Slurm</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/llm_mgmn_trtllm_serve.html">Run trtllm-serve with pytorch backend on Slurm</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../examples/trtllm_serve_examples.html">Online Serving Examples</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/curl_chat_client.html">Curl Chat Client</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/curl_chat_client_for_multimodal.html">Curl Chat Client For Multimodal</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/curl_completion_client.html">Curl Completion Client</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/deepseek_r1_reasoning_parser.html">Deepseek R1 Reasoning Parser</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/genai_perf_client.html">Genai Perf Client</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/genai_perf_client_for_multimodal.html">Genai Perf Client For Multimodal</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/openai_chat_client.html">OpenAI Chat Client</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/openai_chat_client_for_multimodal.html">OpenAI Chat Client for Multimodal</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/openai_completion_client.html">OpenAI Completion Client</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/openai_completion_client_for_lora.html">Openai Completion Client For Lora</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../examples/openai_completion_client_json_schema.html">OpenAI Completion Client with JSON Schema</a></li>
|
||
</ul>
|
||
</details></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../examples/dynamo_k8s_example.html">Dynamo K8s Example</a></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../deployment-guide/index.html">Model Recipes</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../deployment-guide/deployment-guide-for-deepseek-r1-on-trtllm.html">Deployment Guide for DeepSeek R1 on TensorRT LLM - Blackwell & Hopper Hardware</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../deployment-guide/deployment-guide-for-llama3.3-70b-on-trtllm.html">Deployment Guide for Llama3.3 70B on TensorRT LLM - Blackwell & Hopper Hardware</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../deployment-guide/deployment-guide-for-llama4-scout-on-trtllm.html">Deployment Guide for Llama4 Scout 17B on TensorRT LLM - Blackwell & Hopper Hardware</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../deployment-guide/deployment-guide-for-gpt-oss-on-trtllm.html">Deployment Guide for GPT-OSS on TensorRT-LLM - Blackwell Hardware</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../deployment-guide/deployment-guide-for-qwen3-next-on-trtllm.html">Deployment Guide for Qwen3 Next on TensorRT LLM - Blackwell & Hopper Hardware</a></li>
|
||
</ul>
|
||
</details></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Models</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../models/supported-models.html">Supported Models</a></li>
|
||
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../models/adding-new-model.html">Adding a New Model</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">CLI Reference</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../commands/trtllm-bench.html">trtllm-bench</a></li>
|
||
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../commands/trtllm-eval.html">trtllm-eval</a></li>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../../../commands/trtllm-serve/index.html">trtllm-serve</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../commands/trtllm-serve/trtllm-serve.html">trtllm-serve</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="../../../commands/trtllm-serve/run-benchmark-with-trtllm-serve.html">Run benchmarking with <code class="docutils literal notranslate"><span class="pre">trtllm-serve</span></code></a></li>
|
||
</ul>
|
||
</details></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">API Reference</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../llm-api/index.html">LLM API Introduction</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../llm-api/reference.html">API Reference</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Features</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/feature-combination-matrix.html">Feature Combination Matrix</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/attention.html">Multi-Head, Multi-Query, and Group-Query Attention</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/disagg-serving.html">Disaggregated Serving</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/kvcache.html">KV Cache System</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/long-sequence.html">Long Sequences</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/lora.html">LoRA (Low-Rank Adaptation)</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/multi-modality.html">Multimodal Support in TensorRT LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/overlap-scheduler.html">Overlap Scheduler</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/paged-attention-ifb-scheduler.html">Paged Attention, IFB, and Request Scheduling</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/parallel-strategy.html">Parallelism in TensorRT LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/quantization.html">Quantization</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/sampling.html">Sampling</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/additional-outputs.html">Additional Outputs</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/speculative-decoding.html">Speculative Decoding</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/checkpoint-loading.html">Checkpoint Loading</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/auto_deploy/auto-deploy.html">AutoDeploy (Prototype)</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/ray-orchestrator.html">Ray Orchestrator (Prototype)</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../features/torch_compile_and_piecewise_cuda_graph.html">Torch Compile & Piecewise CUDA Graph</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Developer Guide</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../developer-guide/overview.html">Architecture Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../developer-guide/perf-analysis.html">Performance Analysis</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../developer-guide/perf-benchmarking.html">TensorRT LLM Benchmarking</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../developer-guide/ci-overview.html">Continuous Integration Overview</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../developer-guide/dev-containers.html">Using Dev Containers</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../developer-guide/api-change.html">LLM API Change Guide</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../developer-guide/kv-transfer.html">Introduction to KV Cache Transmission</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Blogs</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/tech_blog/blog10_ADP_Balance_Strategy.html">ADP Balance Strategy</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/tech_blog/blog11_GPT_OSS_Eagle3.html">Running GPT-OSS-120B with Eagle3 Speculative Decoding on GB200/B200 (TensorRT LLM)</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/tech_blog/blog12_Combining_Guided_Decoding_and_Speculative_Decoding.html">Combining Guided Decoding and Speculative Decoding: Making CPU and GPU Cooperate Seamlessly</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/tech_blog/blog13_Inference_Time_Compute_Implementation_in_TensorRT-LLM.html">Inference Time Compute Implementation in TensorRT LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/tech_blog/blog14_Scaling_Expert_Parallelism_in_TensorRT-LLM_part3.html">Scaling Expert Parallelism in TensorRT LLM (Part 3: Pushing the Performance Boundary)</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/tech_blog/blog1_Pushing_Latency_Boundaries_Optimizing_DeepSeek-R1_Performance_on_NVIDIA_B200_GPUs.html">Pushing Latency Boundaries: Optimizing DeepSeek-R1 Performance on NVIDIA B200 GPUs</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/tech_blog/blog2_DeepSeek_R1_MTP_Implementation_and_Optimization.html">DeepSeek R1 MTP Implementation and Optimization</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/tech_blog/blog3_Optimizing_DeepSeek_R1_Throughput_on_NVIDIA_Blackwell_GPUs.html">Optimizing DeepSeek R1 Throughput on NVIDIA Blackwell GPUs: A Deep Dive for Developers</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/tech_blog/blog4_Scaling_Expert_Parallelism_in_TensorRT-LLM.html">Scaling Expert Parallelism in TensorRT LLM (Part 1: Design and Implementation of Large-scale EP)</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/tech_blog/blog5_Disaggregated_Serving_in_TensorRT-LLM.html">Disaggregated Serving in TensorRT LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/tech_blog/blog6_Llama4_maverick_eagle_guide.html">How to launch Llama4 Maverick + Eagle3 TensorRT LLM server</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/tech_blog/blog7_NGram_performance_Analysis_And_Auto_Enablement.html">N-Gram Speculative Decoding in TensorRT LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/tech_blog/blog8_Scaling_Expert_Parallelism_in_TensorRT-LLM_part2.html">Scaling Expert Parallelism in TensorRT LLM (Part 2: Performance Status and Optimization)</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.html">Running a High Performance GPT-OSS-120B Inference Server with TensorRT LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.html">How to get best performance on DeepSeek-R1 in TensorRT LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT LLM</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/XQA-kernel.html">New XQA-kernel provides 2.4x more Llama-70B throughput within the same latency budget</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Quick Links</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/releases">Releases</a></li>
|
||
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM">Github Code</a></li>
|
||
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap">Roadmap</a></li>
|
||
</ul>
|
||
<p aria-level="2" class="caption" role="heading"><span class="caption-text">Use TensorRT Engine</span></p>
|
||
<ul class="nav bd-sidenav">
|
||
<li class="toctree-l1"><a class="reference internal" href="../../../legacy/tensorrt_quickstart.html">LLM API with TensorRT Engine</a></li>
|
||
</ul>
|
||
</div>
|
||
</nav></div>
|
||
</div>
|
||
|
||
|
||
<div class="sidebar-primary-items__end sidebar-primary__section">
|
||
</div>
|
||
|
||
|
||
|
||
</div>
|
||
|
||
<main id="main-content" class="bd-main" role="main">
|
||
|
||
|
||
<div class="bd-content">
|
||
<div class="bd-article-container">
|
||
|
||
<div class="bd-header-article d-print-none">
|
||
<div class="header-article-items header-article__inner">
|
||
|
||
<div class="header-article-items__start">
|
||
|
||
<div class="header-article-item">
|
||
|
||
<nav aria-label="Breadcrumb" class="d-print-none">
|
||
<ul class="bd-breadcrumbs">
|
||
|
||
<li class="breadcrumb-item breadcrumb-home">
|
||
<a href="../../../index.html" class="nav-link" aria-label="Home">
|
||
<i class="fa-solid fa-home"></i>
|
||
</a>
|
||
</li>
|
||
|
||
<li class="breadcrumb-item"><a href="../../index.html" class="nav-link">Module code</a></li>
|
||
|
||
<li class="breadcrumb-item active" aria-current="page"><span class="ellipsis">tensorrt_llm.llmapi.llm_args</span></li>
|
||
</ul>
|
||
</nav>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
|
||
</div>
|
||
</div>
|
||
|
||
|
||
|
||
|
||
<div id="searchbox"></div>
|
||
<article class="bd-article">
|
||
|
||
<h1>Source code for tensorrt_llm.llmapi.llm_args</h1><div class="highlight"><pre>
|
||
<span></span><span class="kn">import</span><span class="w"> </span><span class="nn">ast</span>
|
||
<span class="kn">import</span><span class="w"> </span><span class="nn">functools</span>
|
||
<span class="kn">import</span><span class="w"> </span><span class="nn">json</span>
|
||
<span class="kn">import</span><span class="w"> </span><span class="nn">math</span>
|
||
<span class="kn">import</span><span class="w"> </span><span class="nn">os</span>
|
||
<span class="kn">import</span><span class="w"> </span><span class="nn">types</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">abc</span><span class="w"> </span><span class="kn">import</span> <span class="n">ABC</span><span class="p">,</span> <span class="n">abstractmethod</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">dataclasses</span><span class="w"> </span><span class="kn">import</span> <span class="n">dataclass</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">enum</span><span class="w"> </span><span class="kn">import</span> <span class="n">Enum</span><span class="p">,</span> <span class="n">EnumMeta</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">pathlib</span><span class="w"> </span><span class="kn">import</span> <span class="n">Path</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">typing</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span><span class="n">Any</span><span class="p">,</span> <span class="n">ClassVar</span><span class="p">,</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Literal</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Set</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span>
|
||
<span class="n">Type</span><span class="p">,</span> <span class="n">TypeAlias</span><span class="p">,</span> <span class="n">TypeVar</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">get_args</span><span class="p">,</span> <span class="n">get_origin</span><span class="p">)</span>
|
||
|
||
<span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
|
||
<span class="kn">import</span><span class="w"> </span><span class="nn">yaml</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">pydantic</span><span class="w"> </span><span class="kn">import</span> <span class="n">BaseModel</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">pydantic</span><span class="w"> </span><span class="kn">import</span> <span class="n">Field</span> <span class="k">as</span> <span class="n">PydanticField</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">pydantic</span><span class="w"> </span><span class="kn">import</span> <span class="n">PrivateAttr</span><span class="p">,</span> <span class="n">field_validator</span><span class="p">,</span> <span class="n">model_validator</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">strenum</span><span class="w"> </span><span class="kn">import</span> <span class="n">StrEnum</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">transformers</span><span class="w"> </span><span class="kn">import</span> <span class="n">PreTrainedTokenizerBase</span>
|
||
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">tensorrt_llm.lora_helper</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span><span class="n">LoraConfig</span><span class="p">,</span>
|
||
<span class="n">get_default_trtllm_modules_to_hf_modules</span><span class="p">)</span>
|
||
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">.._utils</span><span class="w"> </span><span class="kn">import</span> <span class="n">mpi_rank</span>
|
||
|
||
<span class="c1"># yapf: disable</span>
|
||
<span class="c1"># isort: off</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">..bindings.executor</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span><span class="n">BatchingType</span> <span class="k">as</span> <span class="n">_BatchingType</span><span class="p">,</span>
|
||
<span class="n">CacheTransceiverBackendType</span> <span class="k">as</span> <span class="n">_CacheTransceiverBackendType</span><span class="p">,</span>
|
||
<span class="n">CacheTransceiverConfig</span> <span class="k">as</span> <span class="n">_CacheTransceiverConfig</span><span class="p">,</span>
|
||
<span class="n">CapacitySchedulerPolicy</span> <span class="k">as</span> <span class="n">_CapacitySchedulerPolicy</span><span class="p">,</span>
|
||
<span class="n">ContextChunkingPolicy</span> <span class="k">as</span> <span class="n">_ContextChunkingPolicy</span><span class="p">,</span>
|
||
<span class="n">DecodingConfig</span><span class="p">,</span>
|
||
<span class="n">DecodingMode</span><span class="p">,</span>
|
||
<span class="n">DynamicBatchConfig</span> <span class="k">as</span> <span class="n">_DynamicBatchConfig</span><span class="p">,</span>
|
||
<span class="n">EagleConfig</span> <span class="k">as</span> <span class="n">_EagleConfig</span><span class="p">,</span>
|
||
<span class="n">ExecutorConfig</span> <span class="k">as</span> <span class="n">_ExecutorConfig</span><span class="p">,</span>
|
||
<span class="n">ExtendedRuntimePerfKnobConfig</span> <span class="k">as</span> <span class="n">_ExtendedRuntimePerfKnobConfig</span><span class="p">,</span>
|
||
<span class="n">KvCacheConfig</span> <span class="k">as</span> <span class="n">_KvCacheConfig</span><span class="p">,</span>
|
||
<span class="n">LookaheadDecodingConfig</span> <span class="k">as</span> <span class="n">_LookaheadDecodingConfig</span><span class="p">,</span>
|
||
<span class="n">PeftCacheConfig</span> <span class="k">as</span> <span class="n">_PeftCacheConfig</span><span class="p">,</span>
|
||
<span class="n">SchedulerConfig</span> <span class="k">as</span> <span class="n">_SchedulerConfig</span><span class="p">)</span> <span class="c1"># isort: skip</span>
|
||
<span class="c1"># isort: on</span>
|
||
|
||
<span class="c1"># yapf: enable</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">..builder</span><span class="w"> </span><span class="kn">import</span> <span class="n">BuildConfig</span><span class="p">,</span> <span class="n">EngineConfig</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">..logger</span><span class="w"> </span><span class="kn">import</span> <span class="n">logger</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">..mapping</span><span class="w"> </span><span class="kn">import</span> <span class="n">Mapping</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">..models.automodel</span><span class="w"> </span><span class="kn">import</span> <span class="n">AutoConfig</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">..models.modeling_utils</span><span class="w"> </span><span class="kn">import</span> <span class="p">(</span><span class="n">PretrainedConfig</span><span class="p">,</span> <span class="n">QuantAlgo</span><span class="p">,</span> <span class="n">QuantConfig</span><span class="p">,</span>
|
||
<span class="n">SpeculativeDecodingMode</span><span class="p">)</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">..sampling_params</span><span class="w"> </span><span class="kn">import</span> <span class="n">BatchedLogitsProcessor</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">.build_cache</span><span class="w"> </span><span class="kn">import</span> <span class="n">BuildCacheConfig</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">.tokenizer</span><span class="w"> </span><span class="kn">import</span> <span class="n">TokenizerBase</span><span class="p">,</span> <span class="n">tokenizer_factory</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">.utils</span><span class="w"> </span><span class="kn">import</span> <span class="n">generate_api_docs_as_docstring</span><span class="p">,</span> <span class="n">get_type_repr</span>
|
||
|
||
<span class="c1"># TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import</span>
|
||
|
||
<span class="n">TypeBaseModel</span> <span class="o">=</span> <span class="n">TypeVar</span><span class="p">(</span><span class="s2">"T"</span><span class="p">,</span> <span class="n">bound</span><span class="o">=</span><span class="n">BaseModel</span><span class="p">)</span>
|
||
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">Field</span><span class="p">(</span><span class="n">default</span><span class="p">:</span> <span class="n">Any</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
|
||
<span class="o">*</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Literal</span><span class="p">[</span><span class="s2">"prototype"</span><span class="p">,</span> <span class="s2">"beta"</span><span class="p">,</span> <span class="s2">"deprecated"</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-></span> <span class="n">Any</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""Custom Field wrapper that adds status to json_schema_extra.</span>
|
||
|
||
<span class="sd"> Args:</span>
|
||
<span class="sd"> default: The default value for the field</span>
|
||
<span class="sd"> status: Optional status indicator that gets added to json_schema_extra.</span>
|
||
<span class="sd"> - None: Stable.</span>
|
||
<span class="sd"> - "beta": Recommended for use per the latest documentation.</span>
|
||
<span class="sd"> - "prototype": Not yet stable and subject to breaking changes; intended for experimentation only.</span>
|
||
<span class="sd"> **kwargs: All other arguments passed to the original Pydantic Field</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A Pydantic FieldInfo object with the status added to json_schema_extra if provided</span>
|
||
<span class="sd"> """</span>
|
||
|
||
<span class="k">if</span> <span class="n">status</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">json_schema_extra</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'json_schema_extra'</span><span class="p">,</span> <span class="p">{})</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">json_schema_extra</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="n">json_schema_extra</span><span class="p">[</span><span class="s1">'status'</span><span class="p">]</span> <span class="o">=</span> <span class="n">status</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="c1"># If json_schema_extra is not a dict, create a new dict with the status</span>
|
||
<span class="n">json_schema_extra</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'status'</span><span class="p">:</span> <span class="n">status</span><span class="p">}</span>
|
||
<span class="n">kwargs</span><span class="p">[</span><span class="s1">'json_schema_extra'</span><span class="p">]</span> <span class="o">=</span> <span class="n">json_schema_extra</span>
|
||
|
||
<span class="k">return</span> <span class="n">PydanticField</span><span class="p">(</span><span class="n">default</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">StrictBaseModel</span><span class="p">(</span><span class="n">BaseModel</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> A base model that forbids arbitrary fields.</span>
|
||
<span class="sd"> """</span>
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">Config</span><span class="p">:</span>
|
||
<span class="n">extra</span> <span class="o">=</span> <span class="s2">"forbid"</span> <span class="c1"># globally forbid arbitrary fields</span>
|
||
|
||
|
||
<div class="viewcode-block" id="CudaGraphConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.CudaGraphConfig">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">CudaGraphConfig</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Configuration for CUDA graphs.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="c1"># List of batch sizes to create CUDA graphs for.</span>
|
||
<span class="n">batch_sizes</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"List of batch sizes to create CUDA graphs for."</span><span class="p">)</span>
|
||
|
||
<span class="n">max_batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"Maximum batch size for CUDA graphs."</span><span class="p">)</span>
|
||
|
||
<span class="n">enable_padding</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"If true, batches are rounded up to the nearest cuda_graph_batch_size. This is usually a net win for performance."</span>
|
||
<span class="p">)</span>
|
||
|
||
<div class="viewcode-block" id="CudaGraphConfig.validate_cuda_graph_max_batch_size">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.CudaGraphConfig.validate_cuda_graph_max_batch_size">[docs]</a>
|
||
<span class="nd">@field_validator</span><span class="p">(</span><span class="s1">'max_batch_size'</span><span class="p">)</span>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_cuda_graph_max_batch_size</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">v</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""Validate cuda_graph_config.max_batch_size is non-negative."""</span>
|
||
<span class="k">if</span> <span class="n">v</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"cuda_graph_config.max_batch_size must be non-negative"</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">v</span></div>
|
||
|
||
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_generate_cuda_graph_batch_sizes</span><span class="p">(</span><span class="n">max_batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
|
||
<span class="n">enable_padding</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span>
|
||
<span class="w"> </span><span class="sd">"""Generate a list of batch sizes for CUDA graphs.</span>
|
||
|
||
<span class="sd"> Args:</span>
|
||
<span class="sd"> max_batch_size: Maximum batch size to generate up to</span>
|
||
<span class="sd"> enable_padding: Whether padding is enabled, which affects the batch size distribution</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> List of batch sizes to create CUDA graphs for</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">if</span> <span class="n">enable_padding</span><span class="p">:</span>
|
||
<span class="n">batch_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="mi">8</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">17</span><span class="p">)]</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">batch_sizes</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">32</span><span class="p">))</span> <span class="o">+</span> <span class="p">[</span><span class="mi">32</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">]</span>
|
||
|
||
<span class="c1"># Add powers of 2 up to max_batch_size</span>
|
||
<span class="n">batch_sizes</span> <span class="o">+=</span> <span class="p">[</span>
|
||
<span class="mi">2</span><span class="o">**</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">max_batch_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">)))</span>
|
||
<span class="p">]</span>
|
||
|
||
<span class="c1"># Filter and sort batch sizes</span>
|
||
<span class="n">batch_sizes</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span>
|
||
<span class="p">[</span><span class="n">size</span> <span class="k">for</span> <span class="n">size</span> <span class="ow">in</span> <span class="n">batch_sizes</span> <span class="k">if</span> <span class="n">size</span> <span class="o"><=</span> <span class="n">max_batch_size</span><span class="p">])</span>
|
||
|
||
<span class="c1"># Add max_batch_size if not already included</span>
|
||
<span class="k">if</span> <span class="n">max_batch_size</span> <span class="o">!=</span> <span class="n">batch_sizes</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span>
|
||
<span class="n">batch_sizes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">max_batch_size</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">batch_sizes</span></div>
|
||
|
||
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">GuidedDecodingConfig</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">):</span>
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">GuidedDecodingBackend</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
|
||
<span class="n">XGRAMMAR</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">LLGUIDANCE</span> <span class="o">=</span> <span class="mi">1</span>
|
||
|
||
<span class="n">backend</span><span class="p">:</span> <span class="n">GuidedDecodingBackend</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="n">GuidedDecodingBackend</span><span class="o">.</span><span class="n">XGRAMMAR</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The backend for guided decoding config."</span><span class="p">)</span>
|
||
<span class="n">encoded_vocab</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The encoded vocab for guided decoding config."</span><span class="p">)</span>
|
||
<span class="n">tokenizer_str</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The tokenizer string for guided decoding config."</span><span class="p">)</span>
|
||
<span class="n">stop_token_ids</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The stop token ids for guided decoding config."</span><span class="p">)</span>
|
||
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">BaseSparseAttentionConfig</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Configuration for sparse attention.</span>
|
||
<span class="sd"> """</span>
|
||
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="c1"># dispatch to the correct sparse attention config</span>
|
||
<span class="n">config_classes</span> <span class="o">=</span> <span class="p">{</span>
|
||
<span class="s2">"rocket"</span><span class="p">:</span> <span class="n">RocketSparseAttentionConfig</span><span class="p">,</span>
|
||
<span class="s2">"dsa"</span><span class="p">:</span> <span class="n">DeepSeekSparseAttentionConfig</span><span class="p">,</span>
|
||
<span class="p">}</span>
|
||
|
||
<span class="n">algorithm</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">"algorithm"</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">algorithm</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Sparse attention algorithm is required"</span><span class="p">)</span>
|
||
|
||
<span class="n">config_class</span> <span class="o">=</span> <span class="n">config_classes</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">algorithm</span><span class="o">.</span><span class="n">lower</span><span class="p">())</span>
|
||
<span class="k">if</span> <span class="n">config_class</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid algorithm: </span><span class="si">{</span><span class="n">algorithm</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Remove 'algorithm' before passing to subclass constructor</span>
|
||
<span class="c1"># It's a ClassVar in subclasses, and used for dispatching to the correct subclass</span>
|
||
<span class="n">data</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">data</span><span class="o">.</span><span class="n">items</span><span class="p">()</span> <span class="k">if</span> <span class="n">k</span> <span class="o">!=</span> <span class="s1">'algorithm'</span><span class="p">}</span>
|
||
<span class="k">return</span> <span class="n">config_class</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_check_fields</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">pass</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">supports_backend</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">backend</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Override if the speculation algorithm does not support</span>
|
||
<span class="sd"> a subset of the possible backends.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">return</span> <span class="kc">True</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">get_indices_block_size</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="mi">1</span>
|
||
|
||
|
||
<div class="viewcode-block" id="RocketSparseAttentionConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.RocketSparseAttentionConfig">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">RocketSparseAttentionConfig</span><span class="p">(</span><span class="n">BaseSparseAttentionConfig</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Configuration for RocketKV sparse attention.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">algorithm</span><span class="p">:</span> <span class="n">ClassVar</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"rocket"</span>
|
||
<span class="n">window_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The window size for snap KV."</span><span class="p">)</span>
|
||
<span class="n">kernel_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The kernel size for snap KV."</span><span class="p">)</span>
|
||
<span class="n">topr</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">76</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"Top-r"</span><span class="p">)</span>
|
||
<span class="n">topk</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"Top-k"</span><span class="p">)</span>
|
||
<span class="n">prompt_budget</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1266</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Prompt budget"</span><span class="p">)</span>
|
||
<span class="n">page_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"Page size"</span><span class="p">)</span>
|
||
|
||
<div class="viewcode-block" id="RocketSparseAttentionConfig.from_dict">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.RocketSparseAttentionConfig.from_dict">[docs]</a>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="RocketSparseAttentionConfig.supports_backend">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.RocketSparseAttentionConfig.supports_backend">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">supports_backend</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">backend</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">backend</span> <span class="o">==</span> <span class="s2">"pytorch"</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="RocketSparseAttentionConfig.get_indices_block_size">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.RocketSparseAttentionConfig.get_indices_block_size">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">get_indices_block_size</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">page_size</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="DeepSeekSparseAttentionConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.DeepSeekSparseAttentionConfig">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">DeepSeekSparseAttentionConfig</span><span class="p">(</span><span class="n">BaseSparseAttentionConfig</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Configuration for DeepSeek Sparse Attention.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">algorithm</span><span class="p">:</span> <span class="n">ClassVar</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"dsa"</span>
|
||
<span class="n">index_n_heads</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The number of heads for the indexer."</span><span class="p">)</span>
|
||
<span class="n">index_head_dim</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The dimension of the indexer heads."</span><span class="p">)</span>
|
||
<span class="n">index_topk</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The topk for the indexer."</span><span class="p">)</span>
|
||
<span class="n">indexer_max_chunk_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The maximum chunk size for the indexer."</span><span class="p">)</span>
|
||
|
||
<div class="viewcode-block" id="DeepSeekSparseAttentionConfig.from_dict">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.DeepSeekSparseAttentionConfig.from_dict">[docs]</a>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="DeepSeekSparseAttentionConfig.supports_backend">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.DeepSeekSparseAttentionConfig.supports_backend">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">supports_backend</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">backend</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">backend</span> <span class="o">==</span> <span class="s2">"pytorch"</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">MoeLoadBalancerConfig</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Pydantic configuration model for the Mixture of Experts (MoE) load balancer.</span>
|
||
|
||
<span class="sd"> This model holds configuration data (`num_slots`, etc.) as well as</span>
|
||
<span class="sd"> runtime state (`_ep_rank`, `_ep_size`) which must be set via the</span>
|
||
<span class="sd"> `setup()` method before use.</span>
|
||
<span class="sd"> """</span>
|
||
|
||
<span class="n">num_slots</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="n">initial_global_assignments</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="nb">repr</span><span class="o">=</span><span class="kc">False</span> <span class="c1"># Exclude this large dict from model representation</span>
|
||
<span class="p">)</span>
|
||
<span class="n">layer_updates_per_iter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">_ep_rank</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">PrivateAttr</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
||
<span class="n">_ep_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">PrivateAttr</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
||
|
||
<span class="c1"># --- Methods ---</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ep_rank</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">ep_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Initializes the runtime state of the configuration.</span>
|
||
<span class="sd"> This must be called before accessing properties like `num_local_slots`.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_ep_rank</span> <span class="o">=</span> <span class="n">ep_rank</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_ep_size</span> <span class="o">=</span> <span class="n">ep_size</span>
|
||
|
||
<span class="c1"># This assertion was in the original and is critical.</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_slots</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"`num_slots` cannot be None when calling setup()."</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_slots</span> <span class="o">%</span> <span class="n">ep_size</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"`num_slots` (</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">num_slots</span><span class="si">}</span><span class="s2">) must be divisible by `ep_size` (</span><span class="si">{</span><span class="n">ep_size</span><span class="si">}</span><span class="s2">)."</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="c1"># --- Computed Properties ---</span>
|
||
<span class="c1"># These properties depend on the runtime state set by setup()</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">ep_rank</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""Public accessor for the private expert parallel rank."""</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ep_rank</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">AttributeError</span><span class="p">(</span><span class="s2">"ep_rank is not set. Call setup() first."</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ep_rank</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">ep_size</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""Public accessor for the private expert parallel size."""</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ep_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">AttributeError</span><span class="p">(</span><span class="s2">"ep_size is not set. Call setup() first."</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ep_size</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">num_local_slots</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""Calculates the number of slots local to this rank."""</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_slots</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ep_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"Cannot calculate `num_local_slots`. "</span>
|
||
<span class="s2">"`num_slots` must be set and setup() must be called."</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_slots</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ep_size</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">slot_start</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""Calculates the starting global slot index for this rank."""</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ep_rank</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"Cannot calculate `slot_start`. Call setup() first."</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_ep_rank</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_local_slots</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">slot_end</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""Calculates the ending global slot index (exclusive) for this rank."""</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">slot_start</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_local_slots</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">get_layer_initial_global_assignments</span><span class="p">(</span>
|
||
<span class="bp">self</span><span class="p">,</span> <span class="n">layer_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Retrieves the initial global assignments for a specific layer.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">initial_global_assignments</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="kc">None</span>
|
||
|
||
<span class="k">if</span> <span class="n">layer_idx</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">initial_global_assignments</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">KeyError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"layer_idx </span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s2"> not found in `initial_global_assignments`."</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">assignments</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">initial_global_assignments</span><span class="p">[</span><span class="n">layer_idx</span><span class="p">]</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_slots</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"`num_slots` is not set, cannot verify assignment length."</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">assignments</span><span class="p">)</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_slots</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"Assignment length (</span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">assignments</span><span class="p">)</span><span class="si">}</span><span class="s2">) for layer </span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s2"> "</span>
|
||
<span class="sa">f</span><span class="s2">"does not match `num_slots` (</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">num_slots</span><span class="si">}</span><span class="s2">)."</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">assignments</span>
|
||
|
||
|
||
<div class="viewcode-block" id="MoeConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.MoeConfig">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">MoeConfig</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Configuration for MoE.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">backend</span><span class="p">:</span> <span class="n">Literal</span><span class="p">[</span><span class="s2">"CUTLASS"</span><span class="p">,</span> <span class="s2">"CUTEDSL"</span><span class="p">,</span> <span class="s2">"WIDEEP"</span><span class="p">,</span> <span class="s2">"TRTLLM"</span><span class="p">,</span> <span class="s2">"DEEPGEMM"</span><span class="p">,</span>
|
||
<span class="s2">"VANILLA"</span><span class="p">,</span>
|
||
<span class="s2">"TRITON"</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="s1">'CUTLASS'</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"MoE backend to use."</span><span class="p">)</span>
|
||
|
||
<span class="n">max_num_tokens</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"If set, at most max_num_tokens tokens will be sent to torch.ops.trtllm.fused_moe at the same time. If the number of tokens exceeds max_num_tokens, the input tensors will be split into chunks and a for loop will be used."</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">load_balancer</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">object</span><span class="p">,</span> <span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Configuration for MoE load balancing."</span><span class="p">,</span>
|
||
<span class="n">json_schema_extra</span><span class="o">=</span><span class="p">{</span><span class="s2">"type"</span><span class="p">:</span> <span class="s2">"Union[MoeLoadBalancerConfig, dict, str]"</span><span class="p">})</span>
|
||
|
||
<span class="n">disable_finalize_fusion</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Disable FC2+finalize kernel fusion in CUTLASS MoE backend. Setting this to True recovers deterministic numerical behavior with top-k > 2."</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">use_low_precision_moe_combine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Use low precision combine in MoE operations (only for NVFP4 quantization). When enabled, uses lower precision for combining expert outputs to improve performance."</span>
|
||
<span class="p">)</span>
|
||
|
||
<div class="viewcode-block" id="MoeConfig.from_dict">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.MoeConfig.from_dict">[docs]</a>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="AttentionDpConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.AttentionDpConfig">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">AttentionDpConfig</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Configuration for attention DP.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">enable_balance</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Whether to enable balance."</span><span class="p">)</span>
|
||
<span class="n">timeout_iters</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The number of iterations to timeout."</span><span class="p">)</span>
|
||
<span class="n">batching_wait_iters</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The number of iterations to wait for batching."</span><span class="p">)</span>
|
||
|
||
<div class="viewcode-block" id="AttentionDpConfig.from_dict">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.AttentionDpConfig.from_dict">[docs]</a>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">_ParallelConfig</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""The model distribution configs for LLM."""</span>
|
||
<span class="n">tp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">pp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">cp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">gpus_per_node</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span>
|
||
<span class="c1"># Set default for MoE fields to -1 to trigger auto-calculation in Mapping</span>
|
||
<span class="n">moe_cluster_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span>
|
||
<span class="n">moe_tp_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span>
|
||
<span class="n">moe_ep_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span>
|
||
<span class="n">cp_config</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="nb">dict</span><span class="p">)</span>
|
||
<span class="n">pp_partition</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
||
<span class="n">enable_attention_dp</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
|
||
<span class="n">enable_lm_head_tp_in_adp</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
|
||
|
||
<span class="n">_devices</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="n">PrivateAttr</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">devices</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_devices</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">world_size</span><span class="p">))</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_devices</span>
|
||
|
||
<span class="nd">@devices</span><span class="o">.</span><span class="n">setter</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">devices</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">devices</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]):</span>
|
||
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">devices</span><span class="p">)</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"devices </span><span class="si">{</span><span class="n">devices</span><span class="si">}</span><span class="s2"> should have the same length as world_size </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">world_size</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_devices</span> <span class="o">=</span> <span class="n">devices</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">world_size</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">pp_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">cp_size</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">world_size_per_node</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||
<span class="n">world_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span>
|
||
<span class="n">total_nodes</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">world_size</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">gpus_per_node</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">world_size</span> <span class="o">//</span> <span class="n">total_nodes</span> <span class="c1">#TODO is this right?</span>
|
||
|
||
<span class="nd">@world_size</span><span class="o">.</span><span class="n">setter</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">world_size</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">world_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">world_size</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">pp_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">cp_size</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"world_size </span><span class="si">{</span><span class="n">world_size</span><span class="si">}</span><span class="s2"> should be equal to tp_size * pp_size * cp_size </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="bp">self</span><span class="o">.</span><span class="n">pp_size</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="bp">self</span><span class="o">.</span><span class="n">cp_size</span><span class="si">}</span><span class="s2"> "</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">is_multi_gpu</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">></span> <span class="mi">1</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">to_mapping</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Mapping</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">Mapping</span><span class="p">(</span><span class="n">world_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">world_size</span><span class="p">,</span>
|
||
<span class="n">rank</span><span class="o">=</span><span class="n">mpi_rank</span><span class="p">(),</span>
|
||
<span class="n">gpus_per_node</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">gpus_per_node</span><span class="p">,</span>
|
||
<span class="n">tp_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span>
|
||
<span class="n">pp_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">pp_size</span><span class="p">,</span>
|
||
<span class="n">pp_partition</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">pp_partition</span><span class="p">,</span>
|
||
<span class="n">cp_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">cp_size</span><span class="p">,</span>
|
||
<span class="n">cp_config</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">cp_config</span><span class="p">,</span>
|
||
<span class="n">enable_attention_dp</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">enable_attention_dp</span><span class="p">,</span>
|
||
<span class="n">enable_lm_head_tp_in_adp</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">enable_lm_head_tp_in_adp</span><span class="p">,</span>
|
||
<span class="n">moe_cluster_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_cluster_size</span><span class="p">,</span>
|
||
<span class="n">moe_tp_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_tp_size</span><span class="p">,</span>
|
||
<span class="n">moe_ep_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_ep_size</span><span class="p">)</span>
|
||
|
||
|
||
<div class="viewcode-block" id="CalibConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.CalibConfig">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">CalibConfig</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Calibration configuration.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">device</span><span class="p">:</span> <span class="n">Literal</span><span class="p">[</span><span class="s1">'cuda'</span><span class="p">,</span>
|
||
<span class="s1">'cpu'</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The device to run calibration."</span><span class="p">)</span>
|
||
<span class="n">calib_dataset</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="s1">'cnn_dailymail'</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The name or local path of calibration dataset."</span><span class="p">)</span>
|
||
<span class="n">calib_batches</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The number of batches that the calibration runs."</span><span class="p">)</span>
|
||
<span class="n">calib_batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The batch size that the calibration runs."</span><span class="p">)</span>
|
||
<span class="n">calib_max_seq_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The maximum sequence length that the calibration runs."</span><span class="p">)</span>
|
||
<span class="n">random_seed</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">1234</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The random seed used for calibration."</span><span class="p">)</span>
|
||
<span class="n">tokenizer_max_seq_length</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">2048</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The maximum sequence length to initialize tokenizer for calibration."</span><span class="p">)</span>
|
||
|
||
<div class="viewcode-block" id="CalibConfig.from_dict">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.CalibConfig.from_dict">[docs]</a>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">config</span><span class="p">:</span> <span class="nb">dict</span><span class="p">)</span> <span class="o">-></span> <span class="s1">'CalibConfig'</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""Create a CalibConfig instance from a dict.</span>
|
||
|
||
<span class="sd"> Args:</span>
|
||
<span class="sd"> config (dict): The dict used to create CalibConfig.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> tensorrt_llm.llmapi.CalibConfig: The CalibConfig created from dict.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">config</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="CalibConfig.to_dict">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.CalibConfig.to_dict">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">to_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">dict</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""Dump a CalibConfig instance to a dict.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> dict: The dict dumped from CalibConfig.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_dump</span><span class="p">()</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">_ModelFormatKind</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
|
||
<span class="n">HF</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">TLLM_CKPT</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">TLLM_ENGINE</span> <span class="o">=</span> <span class="mi">2</span>
|
||
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">DecodingBaseConfig</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">):</span>
|
||
<span class="c1"># The number of the drafter layers.</span>
|
||
<span class="n">max_draft_len</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="c1"># The number of draft tokens in the draft tokens tree.</span>
|
||
<span class="c1"># If it's a linear tree, each draft layer will only generate one draft token.</span>
|
||
<span class="c1"># In this case, max_draft_len == max_total_draft_tokens.</span>
|
||
<span class="c1"># If it's a static or dynamic tree, each draft layer may generate more than one draft token.</span>
|
||
<span class="c1"># In this case, max_total_draft_tokens >= max_draft_len.</span>
|
||
<span class="n">max_total_draft_tokens</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="n">speculative_model_dir</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
|
||
<span class="c1"># PyTorch only.</span>
|
||
<span class="c1"># When specified, speculation will be disabled at batch sizes above</span>
|
||
<span class="c1"># this value. Otherwise, speculation will always be on.</span>
|
||
<span class="n">max_concurrency</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
|
||
<span class="c1"># Developer interface: dynamically adjust draft length based on active batch size in runtime.</span>
|
||
<span class="c1"># Maps batch size to draft lengths. For example:</span>
|
||
<span class="c1"># {1: 4, 4: 2, 8: 0} means:</span>
|
||
<span class="c1"># - batch_size >= 1: use draft_len=4</span>
|
||
<span class="c1"># - batch_size >= 4: use draft_len=2</span>
|
||
<span class="c1"># - batch_size >= 8: use draft_len=0 (disable speculation)</span>
|
||
<span class="c1"># draft_len_schedule is enforced to contain batch_size=1 and its according draft_len equals max_draft_len for consistency</span>
|
||
<span class="c1"># for example, if max_draft_len=4, the schedule must contain {1: 4}</span>
|
||
<span class="n">draft_len_schedule</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">dict</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
|
||
<span class="n">load_format</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="c1"># PyTorch only.</span>
|
||
<span class="c1"># Rolling average window size (N) for acceptance length across completed requests.</span>
|
||
<span class="c1"># If not set or set to 0, the feature is disabled.</span>
|
||
<span class="n">acceptance_window</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="c1"># PyTorch only.</span>
|
||
<span class="c1"># Threshold for average acceptance length; speculation will be disabled</span>
|
||
<span class="c1"># permanently once the rolling average over the last N completed requests</span>
|
||
<span class="c1"># (N = acceptance_window) drops below this value.</span>
|
||
<span class="n">acceptance_length_threshold</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
|
||
<span class="c1"># Validate acceptance controls at field level so they run on model creation</span>
|
||
<span class="nd">@field_validator</span><span class="p">(</span><span class="s1">'acceptance_window'</span><span class="p">)</span>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_validate_acceptance_window</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]):</span>
|
||
<span class="k">if</span> <span class="n">v</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">v</span>
|
||
<span class="k">if</span> <span class="n">v</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"acceptance_window must be >= 0 (0 disables), got </span><span class="si">{</span><span class="n">v</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">v</span>
|
||
|
||
<span class="nd">@field_validator</span><span class="p">(</span><span class="s1">'acceptance_length_threshold'</span><span class="p">)</span>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_validate_acceptance_length_threshold</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]):</span>
|
||
<span class="k">if</span> <span class="n">v</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">v</span>
|
||
<span class="k">if</span> <span class="n">v</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"acceptance_length_threshold must be >= 0, got </span><span class="si">{</span><span class="n">v</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">v</span>
|
||
|
||
<span class="c1"># If set, drafting is allowed to use chain drafter.</span>
|
||
<span class="n">_allow_chain_drafter</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">PrivateAttr</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
|
||
<span class="c1"># If set, drafting uses greedy sampling, irrespective of sampling parameters.</span>
|
||
<span class="n">_allow_greedy_draft_tokens</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">PrivateAttr</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
|
||
|
||
<span class="nd">@field_validator</span><span class="p">(</span><span class="s1">'draft_len_schedule'</span><span class="p">)</span>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_draft_len_schedule_and_sort</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">info</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""Validate and sort draft_len_schedule by batch size thresholds."""</span>
|
||
<span class="k">if</span> <span class="n">v</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="c1"># Validate values</span>
|
||
<span class="k">for</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">draft_len</span> <span class="ow">in</span> <span class="n">v</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
||
<span class="k">if</span> <span class="n">batch_size</span> <span class="o"><</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"draft_len_schedule: batch size threshold must be >= 1, got </span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">draft_len</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"draft_len_schedule: draft length must be >= 0, got </span><span class="si">{</span><span class="n">draft_len</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="c1"># Require batch_size=1 in schedule</span>
|
||
<span class="k">if</span> <span class="mi">1</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">v</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"draft_len_schedule must include batch_size=1. "</span>
|
||
<span class="s2">"All systems can have batch_size=1. Add {1: <max_draft_len>} to your schedule."</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="c1"># Enforce schedule[1] == max_draft_len for consistency</span>
|
||
<span class="n">max_draft_len</span> <span class="o">=</span> <span class="n">info</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'max_draft_len'</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">max_draft_len</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">v</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">!=</span> <span class="n">max_draft_len</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"draft_len_schedule[1] must equal max_draft_len for consistency. "</span>
|
||
<span class="sa">f</span><span class="s2">"Got schedule[1]=</span><span class="si">{</span><span class="n">v</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s2">, but max_draft_len=</span><span class="si">{</span><span class="n">max_draft_len</span><span class="si">}</span><span class="s2">. "</span>
|
||
<span class="sa">f</span><span class="s2">"batch_size=1 should use maximum draft length."</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Enforce all draft lengths <= max_draft_len</span>
|
||
<span class="k">if</span> <span class="n">max_draft_len</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">for</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">draft_len</span> <span class="ow">in</span> <span class="n">v</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
||
<span class="k">if</span> <span class="n">draft_len</span> <span class="o">></span> <span class="n">max_draft_len</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"draft_len_schedule: all draft lengths must be <= max_draft_len. "</span>
|
||
<span class="sa">f</span><span class="s2">"Got draft_len=</span><span class="si">{</span><span class="n">draft_len</span><span class="si">}</span><span class="s2"> for batch_size=</span><span class="si">{</span><span class="n">batch_size</span><span class="si">}</span><span class="s2">, "</span>
|
||
<span class="sa">f</span><span class="s2">"but max_draft_len=</span><span class="si">{</span><span class="n">max_draft_len</span><span class="si">}</span><span class="s2">."</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Return sorted dict (by batch size thresholds)</span>
|
||
<span class="c1"># This ensures efficient lookup</span>
|
||
<span class="k">return</span> <span class="nb">dict</span><span class="p">(</span><span class="nb">sorted</span><span class="p">(</span><span class="n">v</span><span class="o">.</span><span class="n">items</span><span class="p">(),</span> <span class="n">key</span><span class="o">=</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span>
|
||
<span class="k">return</span> <span class="n">v</span>
|
||
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="c1"># dispatch to the correct decoding config</span>
|
||
<span class="n">decoding_type</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">"decoding_type"</span><span class="p">)</span>
|
||
<span class="n">config_classes</span> <span class="o">=</span> <span class="p">{</span>
|
||
<span class="s2">"MTP"</span><span class="p">:</span> <span class="n">MTPDecodingConfig</span><span class="p">,</span>
|
||
<span class="s2">"Medusa"</span><span class="p">:</span> <span class="n">MedusaDecodingConfig</span><span class="p">,</span>
|
||
<span class="s2">"Eagle"</span><span class="p">:</span> <span class="n">EagleDecodingConfig</span><span class="p">,</span>
|
||
<span class="s2">"Lookahead"</span><span class="p">:</span> <span class="n">LookaheadDecodingConfig</span><span class="p">,</span>
|
||
<span class="s2">"NGram"</span><span class="p">:</span> <span class="n">NGramDecodingConfig</span><span class="p">,</span>
|
||
<span class="s2">"DraftTarget"</span><span class="p">:</span> <span class="n">DraftTargetDecodingConfig</span><span class="p">,</span>
|
||
<span class="s2">"SaveState"</span><span class="p">:</span> <span class="n">SaveHiddenStatesDecodingConfig</span><span class="p">,</span>
|
||
<span class="s2">"UserProvided"</span><span class="p">:</span> <span class="n">UserProvidedDecodingConfig</span><span class="p">,</span>
|
||
<span class="s2">"AUTO"</span><span class="p">:</span> <span class="n">AutoDecodingConfig</span><span class="p">,</span>
|
||
<span class="p">}</span>
|
||
|
||
<span class="n">config_class</span> <span class="o">=</span> <span class="n">config_classes</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">decoding_type</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">config_class</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid decoding type: </span><span class="si">{</span><span class="n">decoding_type</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="n">data</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">"decoding_type"</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">config_class</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_check_fields</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">pass</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">supports_backend</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">backend</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Override if the speculation algorithm does not support</span>
|
||
<span class="sd"> a subset of the possible backends.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">return</span> <span class="kc">True</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Do any additional error checking here.</span>
|
||
<span class="sd"> """</span>
|
||
|
||
<span class="nd">@functools</span><span class="o">.</span><span class="n">cached_property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">spec_dec_mode</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="c1"># spec_dec_mode has more functionality than the raw decoding_mode string.</span>
|
||
<span class="c1"># Use an alias for the import here to avoid name collisions with the one for the</span>
|
||
<span class="c1"># TRT backend.</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">tensorrt_llm._torch.speculative.interface</span><span class="w"> </span><span class="kn">import</span> \
|
||
<span class="n">SpeculativeDecodingMode</span> <span class="k">as</span> <span class="n">TorchSpeculativeDecodingMode</span>
|
||
<span class="k">return</span> <span class="n">TorchSpeculativeDecodingMode</span><span class="o">.</span><span class="n">from_string</span><span class="p">(</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">decoding_type</span><span class="o">.</span><span class="n">upper</span><span class="p">())</span>
|
||
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">KvCacheConnectorConfig</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Configuration for the KV Cache Connector.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">connector_module</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="o">...</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The import path to the connector module. It will be imported with `importlib.import_module`."</span>
|
||
<span class="p">)</span>
|
||
<span class="n">connector_scheduler_class</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="o">...</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The class name of the scheduler within the module."</span><span class="p">)</span>
|
||
<span class="n">connector_worker_class</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="o">...</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The class name of the worker within the module."</span><span class="p">)</span>
|
||
|
||
|
||
<div class="viewcode-block" id="MedusaDecodingConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.MedusaDecodingConfig">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">MedusaDecodingConfig</span><span class="p">(</span><span class="n">DecodingBaseConfig</span><span class="p">):</span>
|
||
<span class="n">medusa_choices</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="n">num_medusa_heads</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
|
||
<div class="viewcode-block" id="MedusaDecodingConfig.__init__">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.MedusaDecodingConfig.__init__">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_total_draft_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="c1"># Current Medusa only support linear tree</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="MedusaDecodingConfig.from_dict">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.MedusaDecodingConfig.from_dict">[docs]</a>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span></div>
|
||
|
||
|
||
<span class="n">decoding_type</span><span class="p">:</span> <span class="n">ClassVar</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"Medusa"</span>
|
||
|
||
<div class="viewcode-block" id="MedusaDecodingConfig.supports_backend">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.MedusaDecodingConfig.supports_backend">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">supports_backend</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">backend</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">backend</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">(</span><span class="s2">"pytorch"</span><span class="p">,</span> <span class="s2">"_autodeploy"</span><span class="p">)</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="EagleDecodingConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.EagleDecodingConfig">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">EagleDecodingConfig</span><span class="p">(</span><span class="n">DecodingBaseConfig</span><span class="p">):</span>
|
||
<span class="n">eagle_choices</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="n">greedy_sampling</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">bool</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="n">posterior_threshold</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="c1"># Whether to use dynamic tree.</span>
|
||
<span class="n">use_dynamic_tree</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">bool</span><span class="p">]</span> <span class="o">=</span> <span class="kc">False</span>
|
||
<span class="c1"># The topK value for each layer when enable dynamic tree.</span>
|
||
<span class="n">dynamic_tree_max_topK</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="c1"># The number of eagle layer. will not be used in pytorch flow, just for compatibility with TRT flow</span>
|
||
<span class="n">num_eagle_layers</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="c1"># The number of non-leaves in each layer.</span>
|
||
<span class="n">max_non_leaves_per_layer</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="n">eagle3_one_model</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">bool</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="n">eagle3_layers_to_capture</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Set</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
|
||
<div class="viewcode-block" id="EagleDecodingConfig.__init__">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.EagleDecodingConfig.__init__">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||
<span class="k">for</span> <span class="n">attr_name</span><span class="p">,</span> <span class="n">attr_value</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
||
<span class="k">if</span> <span class="n">attr_name</span> <span class="o">==</span> <span class="s1">'max_draft_len'</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">num_eagle_layers</span> <span class="o">=</span> <span class="n">attr_value</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_total_draft_tokens</span> <span class="o">=</span> <span class="n">attr_value</span> <span class="c1"># If using linear-tree, the max_total_draft_tokens is the same as max_draft_len</span>
|
||
<span class="c1"># Convert the data type of Eagle choice from str to List[List[int]]</span>
|
||
<span class="k">if</span> <span class="n">attr_name</span> <span class="o">==</span> <span class="s1">'eagle_choices'</span> <span class="ow">and</span> <span class="n">attr_value</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="s2">"NOTE: The Draft token tree is still under development, PLEASE DO NOT USE IT !!!"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">attr_value</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">attr_value</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
|
||
<span class="n">attr_value</span> <span class="o">=</span> <span class="n">ast</span><span class="o">.</span><span class="n">literal_eval</span><span class="p">(</span>
|
||
<span class="n">attr_value</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">" "</span><span class="p">,</span> <span class="s2">""</span><span class="p">))</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"Wrong eagle choices type. Eagle choices should be a List[List[int]] or a string like [[0], [1], [2], [0, 0], [0, 1]]."</span>
|
||
<span class="p">)</span>
|
||
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attr_name</span><span class="p">,</span> <span class="n">attr_value</span><span class="p">)</span>
|
||
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"max_draft_len is required for Eagle"</span>
|
||
|
||
<span class="c1"># Static tree logic</span>
|
||
<span class="c1"># Checks whether the input eagle choices is valid</span>
|
||
<span class="c1"># and reset the max_draft_len and num_eagle_layers if necessary</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">eagle_choices</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="c1"># If eagle_choices is provided, use_dynamic_tree should not be used</span>
|
||
<span class="k">assert</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_dynamic_tree</span><span class="p">,</span> <span class="s2">"If eagle_choices is provided, use_dynamic_tree need to be False"</span>
|
||
|
||
<span class="c1"># Get num_eagle_layers from eagle_choices</span>
|
||
<span class="n">num_eagle_layers_from_choices</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">check_eagle_choices</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">num_eagle_layers_from_choices</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_eagle_layers</span><span class="p">:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"Base on the input choices, reset the num_eagle_layers(max_draft_len) from </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">num_eagle_layers</span><span class="si">}</span><span class="s2"> to </span><span class="si">{</span><span class="n">num_eagle_layers_from_choices</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">num_eagle_layers</span> <span class="o">=</span> <span class="n">num_eagle_layers_from_choices</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">=</span> <span class="n">num_eagle_layers_from_choices</span>
|
||
|
||
<span class="c1"># Each draft node has a path(choice) from the root to it.</span>
|
||
<span class="c1"># So the number of choices also represents the number of max draft nodes.</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_total_draft_tokens</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">eagle_choices</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Dynamic tree logic</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_dynamic_tree</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">eagle_choices</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"If use_dynamic_tree is True, eagle_choices should be None"</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">></span> <span class="mi">0</span><span class="p">,</span> <span class="s2">"max_draft_len should be provided, which indicates the number of drafter layers"</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_tree_max_topK</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_tree_max_topK</span> <span class="o">></span> <span class="mi">0</span><span class="p">,</span> <span class="s2">"dynamic_tree_max_topK should be provided, which indicates the number of nodes to expand each time"</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_total_draft_tokens</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_total_draft_tokens</span> <span class="o">></span> <span class="mi">0</span><span class="p">,</span> <span class="s2">"max_total_draft_tokens should be provided, which indicates the total nodes of the final draft tree. (exclude the root node)"</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="EagleDecodingConfig.from_dict">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.EagleDecodingConfig.from_dict">[docs]</a>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span></div>
|
||
|
||
|
||
<span class="n">decoding_type</span><span class="p">:</span> <span class="n">ClassVar</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"Eagle"</span>
|
||
|
||
<div class="viewcode-block" id="EagleDecodingConfig.validate">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.EagleDecodingConfig.validate">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_model_dir</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Draft model must be provided for EAGLE"</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="EagleDecodingConfig.check_eagle_choices">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.EagleDecodingConfig.check_eagle_choices">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">check_eagle_choices</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="c1"># 1) Check connectivity</span>
|
||
<span class="n">unique_choices</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span>
|
||
<span class="nb">tuple</span><span class="p">(</span><span class="n">sub_choice</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">sub_choice</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">eagle_choices</span><span class="p">)</span> <span class="c1"># remove repeated choices</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">eagle_choices</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">([</span><span class="nb">list</span><span class="p">(</span><span class="n">t</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">unique_choices</span><span class="p">],</span>
|
||
<span class="n">key</span><span class="o">=</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">x</span><span class="p">))</span> <span class="c1"># sort choices</span>
|
||
<span class="k">for</span> <span class="n">choice</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">eagle_choices</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">choice</span><span class="p">)</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="n">choice</span><span class="p">[</span>
|
||
<span class="mi">0</span><span class="p">:</span>
|
||
<span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">eagle_choices</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"Error: choice </span><span class="si">{</span><span class="n">choice</span><span class="si">}</span><span class="s2"> is not connected"</span>
|
||
|
||
<span class="c1"># 2) Get num_eagle_layers_from_choices</span>
|
||
<span class="n">num_eagle_layers_from_choices</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span>
|
||
<span class="nb">len</span><span class="p">(</span><span class="n">choice</span><span class="p">)</span> <span class="k">for</span> <span class="n">choice</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">eagle_choices</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="n">num_eagle_layers_from_choices</span></div>
|
||
|
||
|
||
<span class="nd">@functools</span><span class="o">.</span><span class="n">cached_property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">spec_dec_mode</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">tensorrt_llm._torch.speculative.interface</span><span class="w"> </span><span class="kn">import</span> \
|
||
<span class="n">SpeculativeDecodingMode</span> <span class="k">as</span> <span class="n">TorchSpeculativeDecodingMode</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">eagle3_one_model</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">TorchSpeculativeDecodingMode</span><span class="o">.</span><span class="n">EAGLE3_ONE_MODEL</span>
|
||
<span class="k">return</span> <span class="n">TorchSpeculativeDecodingMode</span><span class="o">.</span><span class="n">EAGLE3</span>
|
||
|
||
<span class="nd">@functools</span><span class="o">.</span><span class="n">cached_property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">num_capture_layers</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Returns the number of layers to capture of the target model.</span>
|
||
<span class="sd"> If eagle3_layers_to_capture is not None, return the length of the set.</span>
|
||
<span class="sd"> Otherwise, assume Eagle3 base set and return 3.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">eagle3_layers_to_capture</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">eagle3_layers_to_capture</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="mi">3</span>
|
||
|
||
<span class="nd">@functools</span><span class="o">.</span><span class="n">cached_property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">is_linear_tree</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">eagle_choices</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_dynamic_tree</span> <span class="ow">is</span> <span class="kc">False</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="kc">True</span>
|
||
<span class="k">return</span> <span class="kc">False</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="SaveHiddenStatesDecodingConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.SaveHiddenStatesDecodingConfig">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">SaveHiddenStatesDecodingConfig</span><span class="p">(</span><span class="n">DecodingBaseConfig</span><span class="p">):</span>
|
||
<span class="n">output_directory</span><span class="p">:</span> <span class="nb">str</span>
|
||
<span class="n">write_interval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">20</span>
|
||
<span class="n">file_prefix</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">"data"</span>
|
||
<span class="n">eagle3_layers_to_capture</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Set</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span>
|
||
|
||
<span class="n">max_total_draft_tokens</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||
<span class="n">eagle_choices</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">init</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
||
|
||
<div class="viewcode-block" id="SaveHiddenStatesDecodingConfig.model_post_init">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.SaveHiddenStatesDecodingConfig.model_post_init">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">model_post_init</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">__context</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_last_hidden_in_save</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">eagle3_layers_to_capture</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_last_hidden_in_save</span> <span class="o">=</span> <span class="kc">False</span>
|
||
<span class="k">elif</span> <span class="o">-</span><span class="mi">1</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">eagle3_layers_to_capture</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_last_hidden_in_save</span> <span class="o">=</span> <span class="kc">False</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">eagle3_layers_to_capture</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="SaveHiddenStatesDecodingConfig.from_dict">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.SaveHiddenStatesDecodingConfig.from_dict">[docs]</a>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span></div>
|
||
|
||
|
||
<span class="n">decoding_type</span><span class="p">:</span> <span class="n">ClassVar</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"SaveState"</span>
|
||
|
||
<div class="viewcode-block" id="SaveHiddenStatesDecodingConfig.validate">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.SaveHiddenStatesDecodingConfig.validate">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_directory</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">eagle3_layers_to_capture</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"Save directory and layers to capture must be provided"</span><span class="p">)</span></div>
|
||
|
||
|
||
<span class="nd">@functools</span><span class="o">.</span><span class="n">cached_property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">spec_dec_mode</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">tensorrt_llm._torch.speculative.interface</span><span class="w"> </span><span class="kn">import</span> \
|
||
<span class="n">SpeculativeDecodingMode</span> <span class="k">as</span> <span class="n">TorchSpeculativeDecodingMode</span>
|
||
<span class="k">return</span> <span class="n">TorchSpeculativeDecodingMode</span><span class="o">.</span><span class="n">SAVE_HIDDEN_STATES</span>
|
||
|
||
<span class="nd">@functools</span><span class="o">.</span><span class="n">cached_property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">num_capture_layers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Returns the number of layers to capture of the target model.</span>
|
||
<span class="sd"> If eagle3_layers_to_capture is not None, return the length of the set.</span>
|
||
<span class="sd"> Otherwise, assume Eagle3 base set and return 3 + 1 (for post norm last hidden state).</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">eagle3_layers_to_capture</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="mi">4</span>
|
||
<span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">eagle3_layers_to_capture</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="UserProvidedDecodingConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.UserProvidedDecodingConfig">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">UserProvidedDecodingConfig</span><span class="p">(</span><span class="n">DecodingBaseConfig</span><span class="p">):</span>
|
||
<span class="c1"># Cannot use real type annotations due to circular imports</span>
|
||
<span class="n">drafter</span><span class="p">:</span> <span class="nb">object</span> <span class="c1"># Type is Drafter</span>
|
||
<span class="n">resource_manager</span><span class="p">:</span> <span class="nb">object</span> <span class="o">=</span> <span class="kc">None</span> <span class="c1"># Type is Optional[ResourceManager]</span>
|
||
|
||
<div class="viewcode-block" id="UserProvidedDecodingConfig.__init__">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.UserProvidedDecodingConfig.__init__">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_total_draft_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="c1"># Current UserProvided only support linear tree</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="UserProvidedDecodingConfig.from_dict">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.UserProvidedDecodingConfig.from_dict">[docs]</a>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span></div>
|
||
|
||
|
||
<span class="n">decoding_type</span><span class="p">:</span> <span class="n">ClassVar</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"User_Provided"</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="NGramDecodingConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.NGramDecodingConfig">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">NGramDecodingConfig</span><span class="p">(</span><span class="n">DecodingBaseConfig</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Configuration for NGram drafter speculative decoding.</span>
|
||
|
||
<span class="sd"> Arguments:</span>
|
||
<span class="sd"> max_draft_len: int</span>
|
||
<span class="sd"> The length maximum of draft tokens (can be understood as length maximum of output draft tokens).</span>
|
||
|
||
<span class="sd"> max_matching_ngram_size: int</span>
|
||
<span class="sd"> The length maximum of searching tokens (can be understood as length maximum of input tokens to search).</span>
|
||
|
||
<span class="sd"> is_keep_all: bool = True</span>
|
||
<span class="sd"> Whether to keep all candidate pattern-matches pairs, only one match is kept for each pattern if False.</span>
|
||
|
||
<span class="sd"> is_use_oldest: bool = True</span>
|
||
<span class="sd"> Whether to provide the oldest match when pattern is hit, the newest one is provided if False.</span>
|
||
|
||
<span class="sd"> is_public_pool: bool = True</span>
|
||
<span class="sd"> Whether to use a common pool for all requests, or the pool is private for each request if False.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">max_matching_ngram_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="n">is_keep_all</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="n">is_use_oldest</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="n">is_public_pool</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
|
||
|
||
<div class="viewcode-block" id="NGramDecodingConfig.__init__">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.NGramDecodingConfig.__init__">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_total_draft_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="c1"># Current NGram only support linear tree</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="NGramDecodingConfig.from_dict">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.NGramDecodingConfig.from_dict">[docs]</a>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span></div>
|
||
|
||
|
||
<span class="n">decoding_type</span><span class="p">:</span> <span class="n">ClassVar</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"NGram"</span>
|
||
|
||
<div class="viewcode-block" id="NGramDecodingConfig.supports_backend">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.NGramDecodingConfig.supports_backend">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">supports_backend</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">backend</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">backend</span> <span class="o">==</span> <span class="s2">"pytorch"</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="DraftTargetDecodingConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.DraftTargetDecodingConfig">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">DraftTargetDecodingConfig</span><span class="p">(</span><span class="n">DecodingBaseConfig</span><span class="p">):</span>
|
||
|
||
<div class="viewcode-block" id="DraftTargetDecodingConfig.__init__">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.DraftTargetDecodingConfig.__init__">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_total_draft_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="c1"># Current DraftTarget only support linear tree</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="DraftTargetDecodingConfig.from_dict">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.DraftTargetDecodingConfig.from_dict">[docs]</a>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span></div>
|
||
|
||
|
||
<span class="n">decoding_type</span><span class="p">:</span> <span class="n">ClassVar</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"Draft_Target"</span>
|
||
|
||
<div class="viewcode-block" id="DraftTargetDecodingConfig.supports_backend">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.DraftTargetDecodingConfig.supports_backend">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">supports_backend</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">backend</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">backend</span> <span class="o">==</span> <span class="s2">"pytorch"</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="MTPDecodingConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.MTPDecodingConfig">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">MTPDecodingConfig</span><span class="p">(</span><span class="n">DecodingBaseConfig</span><span class="p">):</span>
|
||
<span class="n">num_nextn_predict_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">use_relaxed_acceptance_for_thinking</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
|
||
<span class="n">relaxed_topk</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="n">relaxed_delta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.</span>
|
||
<span class="n">use_mtp_vanilla</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
|
||
<span class="n">mtp_eagle_one_model</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
|
||
|
||
<span class="c1"># TODO: remove this after distinguishing `max_draft_len` and `num_nextn_predict_layers`</span>
|
||
<span class="c1"># Now we need a flag when MTPDecodingConfig is updated by PyTorchModelEngine.</span>
|
||
<span class="n">num_nextn_predict_layers_from_model_config</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
|
||
|
||
<span class="c1"># When encounter <think>, start thinking phase.</span>
|
||
<span class="c1"># When encounter </think>, end thinking phase.</span>
|
||
<span class="c1"># <think> [thinking phase] </think> [real output]</span>
|
||
<span class="n">begin_thinking_phase_token</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">128798</span>
|
||
<span class="n">end_thinking_phase_token</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">128799</span>
|
||
|
||
<div class="viewcode-block" id="MTPDecodingConfig.__init__">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.MTPDecodingConfig.__init__">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="s1">'num_nextn_predict_layers'</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">=</span> <span class="n">kwargs</span><span class="p">[</span><span class="s1">'num_nextn_predict_layers'</span><span class="p">]</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_total_draft_tokens</span> <span class="o">=</span> <span class="n">kwargs</span><span class="p">[</span>
|
||
<span class="s1">'num_nextn_predict_layers'</span><span class="p">]</span> <span class="c1"># Current MTP only support linear tree</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="MTPDecodingConfig.from_dict">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.MTPDecodingConfig.from_dict">[docs]</a>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="n">out</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span>
|
||
<span class="n">out</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">num_nextn_predict_layers</span>
|
||
<span class="n">out</span><span class="o">.</span><span class="n">max_total_draft_tokens</span> <span class="o">=</span> <span class="n">out</span><span class="o">.</span><span class="n">num_nextn_predict_layers</span> <span class="c1"># Current MTP only support linear tree</span>
|
||
<span class="k">return</span> <span class="n">out</span></div>
|
||
|
||
|
||
<span class="n">decoding_type</span><span class="p">:</span> <span class="n">ClassVar</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"MTP"</span>
|
||
|
||
<div class="viewcode-block" id="MTPDecodingConfig.supports_backend">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.MTPDecodingConfig.supports_backend">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">supports_backend</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">backend</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">backend</span> <span class="o">==</span> <span class="s2">"pytorch"</span></div>
|
||
|
||
|
||
<span class="nd">@functools</span><span class="o">.</span><span class="n">cached_property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">num_capture_layers</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">int</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_mtp_vanilla</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">mtp_eagle_one_model</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="mi">1</span>
|
||
<span class="k">return</span> <span class="mi">0</span>
|
||
|
||
<span class="nd">@functools</span><span class="o">.</span><span class="n">cached_property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">spec_dec_mode</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="kn">from</span><span class="w"> </span><span class="nn">tensorrt_llm._torch.speculative.interface</span><span class="w"> </span><span class="kn">import</span> \
|
||
<span class="n">SpeculativeDecodingMode</span> <span class="k">as</span> <span class="n">TorchSpeculativeDecodingMode</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_nextn_predict_layers_from_model_config</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_mtp_vanilla</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mtp_eagle_one_model</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">TorchSpeculativeDecodingMode</span><span class="o">.</span><span class="n">MTP_EAGLE_ONE_MODEL</span>
|
||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_nextn_predict_layers_from_model_config</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_mtp_vanilla</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">mtp_eagle_one_model</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">TorchSpeculativeDecodingMode</span><span class="o">.</span><span class="n">MTP_EAGLE</span>
|
||
<span class="k">return</span> <span class="n">TorchSpeculativeDecodingMode</span><span class="o">.</span><span class="n">MTP</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="AutoDecodingConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.AutoDecodingConfig">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">AutoDecodingConfig</span><span class="p">(</span><span class="n">DecodingBaseConfig</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Configuration for auto speculative decoding.</span>
|
||
|
||
<span class="sd"> This config will automatically select a good, draft-model free</span>
|
||
<span class="sd"> speculation algorithm with some heuristic.</span>
|
||
|
||
<span class="sd"> Attributes that are inherited from the base class are ignored.</span>
|
||
<span class="sd"> """</span>
|
||
|
||
<div class="viewcode-block" id="AutoDecodingConfig.__init__">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.AutoDecodingConfig.__init__">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_total_draft_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="c1"># Current Auto only support linear tree</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="AutoDecodingConfig.from_dict">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.AutoDecodingConfig.from_dict">[docs]</a>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span></div>
|
||
|
||
|
||
<span class="n">decoding_type</span><span class="p">:</span> <span class="n">ClassVar</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"AUTO"</span>
|
||
|
||
<div class="viewcode-block" id="AutoDecodingConfig.supports_backend">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.AutoDecodingConfig.supports_backend">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">supports_backend</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">backend</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">backend</span> <span class="o">==</span> <span class="s2">"pytorch"</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">PybindMirror</span><span class="p">(</span><span class="n">ABC</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">''' A class containing the utilities for mirroring Python classes to</span>
|
||
<span class="sd"> pybinding classes.</span>
|
||
<span class="sd"> '''</span>
|
||
|
||
<span class="nd">@abstractmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_to_pybind</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">pass</span>
|
||
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">maybe_to_pybind</span><span class="p">(</span><span class="n">ins</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
|
||
<span class="n">ins</span><span class="p">,</span>
|
||
<span class="n">PybindMirror</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">type</span><span class="p">(</span><span class="n">ins</span><span class="p">)</span><span class="o">.</span><span class="vm">__class__</span> <span class="o">==</span> <span class="n">PybindMirrorEnumMeta</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">ins</span><span class="o">.</span><span class="n">_to_pybind</span><span class="p">()</span>
|
||
<span class="k">return</span> <span class="n">ins</span>
|
||
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">mirror_pybind_fields</span><span class="p">(</span><span class="n">pybind_class</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Class decorator that ensures Python class fields mirror those of a C++ class.</span>
|
||
|
||
<span class="sd"> Args:</span>
|
||
<span class="sd"> pybind_class: The C++ class whose fields should be mirrored</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> A decorator function that validates field mirroring</span>
|
||
<span class="sd"> """</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">decorator</span><span class="p">(</span><span class="bp">cls</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="nb">issubclass</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">StrictBaseModel</span><span class="p">)</span>
|
||
<span class="c1"># Get all non-private fields from the C++ class</span>
|
||
<span class="n">cpp_fields</span> <span class="o">=</span> <span class="n">PybindMirror</span><span class="o">.</span><span class="n">get_pybind_variable_fields</span><span class="p">(</span><span class="n">pybind_class</span><span class="p">)</span>
|
||
<span class="n">python_fields</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="bp">cls</span><span class="o">.</span><span class="n">model_fields</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
|
||
|
||
<span class="c1"># Check if all C++ fields exist in the Python class</span>
|
||
<span class="k">for</span> <span class="n">field</span> <span class="ow">in</span> <span class="n">cpp_fields</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">field</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">python_fields</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"Field </span><span class="si">{</span><span class="n">field</span><span class="si">}</span><span class="s2"> is not mirrored in Python class </span><span class="si">{</span><span class="bp">cls</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2"> from C++ class </span><span class="si">{</span><span class="n">pybind_class</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2">. Please update the class."</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="c1"># Return the original class</span>
|
||
<span class="k">return</span> <span class="bp">cls</span>
|
||
|
||
<span class="k">return</span> <span class="n">decorator</span>
|
||
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">get_pybind_enum_fields</span><span class="p">(</span><span class="n">pybind_class</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">''' Get all the enum fields from the pybind class. '''</span>
|
||
<span class="k">return</span> <span class="p">[</span>
|
||
<span class="n">f</span> <span class="k">for</span> <span class="n">f</span> <span class="ow">in</span> <span class="n">pybind_class</span><span class="o">.</span><span class="n">__members__</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">f</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">'_'</span><span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span> <span class="nb">callable</span><span class="p">(</span><span class="nb">getattr</span><span class="p">(</span><span class="n">pybind_class</span><span class="p">,</span> <span class="n">f</span><span class="p">))</span>
|
||
<span class="p">]</span>
|
||
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">mirror_pybind_enum</span><span class="p">(</span><span class="n">pybind_class</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">''' Mirror the enum fields from the pybind class to the Python class. '''</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">decorator</span><span class="p">(</span><span class="bp">cls</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="nb">issubclass</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">Enum</span><span class="p">)</span>
|
||
<span class="n">cpp_fields</span> <span class="o">=</span> <span class="n">PybindMirror</span><span class="o">.</span><span class="n">get_pybind_enum_fields</span><span class="p">(</span><span class="n">pybind_class</span><span class="p">)</span>
|
||
<span class="n">python_fields</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="bp">cls</span><span class="o">.</span><span class="n">__members__</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
|
||
|
||
<span class="k">for</span> <span class="n">field</span> <span class="ow">in</span> <span class="n">cpp_fields</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">field</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">python_fields</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"Field </span><span class="si">{</span><span class="n">field</span><span class="si">}</span><span class="s2"> is not mirrored in Python class </span><span class="si">{</span><span class="bp">cls</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2"> from C++ class </span><span class="si">{</span><span class="n">pybind_class</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2">. Please update the class."</span>
|
||
<span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">cls</span>
|
||
|
||
<span class="k">return</span> <span class="n">decorator</span>
|
||
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">get_pybind_variable_fields</span><span class="p">(</span><span class="n">config_cls</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">''' Get all the variable fields from the pybind class. '''</span>
|
||
<span class="k">return</span> <span class="p">[</span>
|
||
<span class="n">f</span> <span class="k">for</span> <span class="n">f</span> <span class="ow">in</span> <span class="nb">dir</span><span class="p">(</span><span class="n">config_cls</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">f</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">'_'</span><span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span> <span class="nb">callable</span><span class="p">(</span><span class="nb">getattr</span><span class="p">(</span><span class="n">config_cls</span><span class="p">,</span> <span class="n">f</span><span class="p">))</span>
|
||
<span class="p">]</span>
|
||
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">pybind_equals</span><span class="p">(</span><span class="n">obj0</span><span class="p">,</span> <span class="n">obj1</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">''' Check if two pybind objects are equal. '''</span>
|
||
<span class="k">assert</span> <span class="nb">type</span><span class="p">(</span><span class="n">obj0</span><span class="p">)</span> <span class="ow">is</span> <span class="nb">type</span><span class="p">(</span><span class="n">obj1</span><span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">field</span> <span class="ow">in</span> <span class="n">PybindMirror</span><span class="o">.</span><span class="n">get_pybind_variable_fields</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">obj0</span><span class="p">)):</span>
|
||
<span class="k">if</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">obj0</span><span class="p">,</span> <span class="n">field</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">obj1</span><span class="p">,</span> <span class="n">field</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="kc">False</span>
|
||
<span class="k">return</span> <span class="kc">True</span>
|
||
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_pybind</span><span class="p">(</span><span class="bp">cls</span><span class="p">:</span> <span class="n">Type</span><span class="p">[</span><span class="n">TypeBaseModel</span><span class="p">],</span>
|
||
<span class="n">pybind_instance</span><span class="p">:</span> <span class="s2">"PybindMirror"</span><span class="p">)</span> <span class="o">-></span> <span class="n">TypeBaseModel</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""Construct an instance of the given class from the fields in the given</span>
|
||
<span class="sd"> pybind class instance.</span>
|
||
|
||
<span class="sd"> Args:</span>
|
||
<span class="sd"> cls: Type of the class to construct, must be a subclass of pydantic</span>
|
||
<span class="sd"> BaseModel</span>
|
||
<span class="sd"> pybind_instance: Instance of the pybind class to construct from its</span>
|
||
<span class="sd"> fields</span>
|
||
|
||
<span class="sd"> Notes:</span>
|
||
<span class="sd"> When a field value is None in the pybind class, but it's not</span>
|
||
<span class="sd"> optional and has a default value in the BaseModel class, it would</span>
|
||
<span class="sd"> get the default value defined in the BaseModel class.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> Instance of the given class, populated with the fields of the given</span>
|
||
<span class="sd"> pybind instance</span>
|
||
<span class="sd"> """</span> <span class="c1"># noqa: D205</span>
|
||
<span class="k">assert</span> <span class="nb">issubclass</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">BaseModel</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Some of the fields are optional in the C++ class but in python they aren't</span>
|
||
<span class="c1"># optional and have a default value, so copy the value from C++ instance</span>
|
||
<span class="c1"># only if it has a value, so otherwise the default value defined in the</span>
|
||
<span class="c1"># python class would be set.</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_is_optional_type</span><span class="p">(</span><span class="n">annotation</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""Returns True if a type annotation represents an Optional type</span>
|
||
<span class="sd"> (Optional[X]) or a Union type that includes None (Union[X, Y, None]</span>
|
||
<span class="sd"> or X | Y | None).</span>
|
||
<span class="sd"> """</span> <span class="c1"># noqa: D205</span>
|
||
<span class="n">origin</span> <span class="o">=</span> <span class="n">get_origin</span><span class="p">(</span><span class="n">annotation</span><span class="p">)</span>
|
||
<span class="n">args</span> <span class="o">=</span> <span class="n">get_args</span><span class="p">(</span><span class="n">annotation</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Union is for Optional[x]</span>
|
||
<span class="c1"># UnionType is for the new | operation in Python 3.10+</span>
|
||
<span class="k">return</span> <span class="p">(</span><span class="n">origin</span> <span class="ow">is</span> <span class="n">Union</span>
|
||
<span class="ow">or</span> <span class="n">origin</span> <span class="ow">is</span> <span class="n">types</span><span class="o">.</span><span class="n">UnionType</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">type</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span> <span class="ow">in</span> <span class="n">args</span>
|
||
|
||
<span class="n">fields_non_optional_with_default_value_in_basemodel</span> <span class="o">=</span> <span class="p">{</span>
|
||
<span class="n">field_name</span>
|
||
<span class="k">for</span> <span class="n">field_name</span><span class="p">,</span> <span class="n">field_info</span> <span class="ow">in</span> <span class="bp">cls</span><span class="o">.</span><span class="n">model_fields</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="n">_is_optional_type</span><span class="p">(</span><span class="n">field_info</span><span class="o">.</span><span class="n">annotation</span><span class="p">)</span>
|
||
<span class="ow">and</span> <span class="n">field_info</span><span class="o">.</span><span class="n">is_required</span><span class="p">())</span>
|
||
<span class="p">}</span>
|
||
|
||
<span class="n">kwargs</span> <span class="o">=</span> <span class="p">{}</span>
|
||
<span class="n">cpp_fields</span> <span class="o">=</span> <span class="n">PybindMirror</span><span class="o">.</span><span class="n">get_pybind_variable_fields</span><span class="p">(</span>
|
||
<span class="nb">type</span><span class="p">(</span><span class="n">pybind_instance</span><span class="p">))</span>
|
||
<span class="k">for</span> <span class="n">field_name</span> <span class="ow">in</span> <span class="n">cpp_fields</span><span class="p">:</span>
|
||
<span class="n">field_value</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">pybind_instance</span><span class="p">,</span> <span class="n">field_name</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">field_value</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">field_name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">fields_non_optional_with_default_value_in_basemodel</span><span class="p">:</span>
|
||
<span class="n">kwargs</span><span class="p">[</span><span class="n">field_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">field_value</span>
|
||
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">PybindMirrorMeta</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">PybindMirror</span><span class="p">)):</span>
|
||
<span class="k">pass</span>
|
||
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">PybindMirrorEnumMeta</span><span class="p">(</span><span class="n">EnumMeta</span><span class="p">,</span> <span class="n">PybindMirrorMeta</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Combined metaclass for Enum and PybindMirror. This is crucial.</span>
|
||
<span class="sd"> """</span>
|
||
|
||
|
||
<div class="viewcode-block" id="BatchingType">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.BatchingType">[docs]</a>
|
||
<span class="nd">@PybindMirror</span><span class="o">.</span><span class="n">mirror_pybind_enum</span><span class="p">(</span><span class="n">_BatchingType</span><span class="p">)</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">BatchingType</span><span class="p">(</span><span class="n">StrEnum</span><span class="p">,</span> <span class="n">metaclass</span><span class="o">=</span><span class="n">PybindMirrorEnumMeta</span><span class="p">):</span>
|
||
<span class="n">STATIC</span> <span class="o">=</span> <span class="s2">"STATIC"</span>
|
||
<span class="n">INFLIGHT</span> <span class="o">=</span> <span class="s2">"INFLIGHT"</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_to_pybind</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">_BatchingType</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="CapacitySchedulerPolicy">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.CapacitySchedulerPolicy">[docs]</a>
|
||
<span class="nd">@PybindMirror</span><span class="o">.</span><span class="n">mirror_pybind_enum</span><span class="p">(</span><span class="n">_CapacitySchedulerPolicy</span><span class="p">)</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">CapacitySchedulerPolicy</span><span class="p">(</span><span class="n">StrEnum</span><span class="p">,</span> <span class="n">metaclass</span><span class="o">=</span><span class="n">PybindMirrorEnumMeta</span><span class="p">):</span>
|
||
<span class="n">MAX_UTILIZATION</span> <span class="o">=</span> <span class="s2">"MAX_UTILIZATION"</span>
|
||
<span class="n">GUARANTEED_NO_EVICT</span> <span class="o">=</span> <span class="s2">"GUARANTEED_NO_EVICT"</span>
|
||
<span class="n">STATIC_BATCH</span> <span class="o">=</span> <span class="s2">"STATIC_BATCH"</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_to_pybind</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">_CapacitySchedulerPolicy</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="ContextChunkingPolicy">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.ContextChunkingPolicy">[docs]</a>
|
||
<span class="nd">@PybindMirror</span><span class="o">.</span><span class="n">mirror_pybind_enum</span><span class="p">(</span><span class="n">_ContextChunkingPolicy</span><span class="p">)</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">ContextChunkingPolicy</span><span class="p">(</span><span class="n">StrEnum</span><span class="p">,</span> <span class="n">metaclass</span><span class="o">=</span><span class="n">PybindMirrorEnumMeta</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">''' Context chunking policy. '''</span>
|
||
<span class="n">FIRST_COME_FIRST_SERVED</span> <span class="o">=</span> <span class="s2">"FIRST_COME_FIRST_SERVED"</span>
|
||
<span class="n">EQUAL_PROGRESS</span> <span class="o">=</span> <span class="s2">"EQUAL_PROGRESS"</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_to_pybind</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">_ContextChunkingPolicy</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">value</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="DynamicBatchConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.DynamicBatchConfig">[docs]</a>
|
||
<span class="nd">@PybindMirror</span><span class="o">.</span><span class="n">mirror_pybind_fields</span><span class="p">(</span><span class="n">_DynamicBatchConfig</span><span class="p">)</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">DynamicBatchConfig</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">,</span> <span class="n">PybindMirror</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""Dynamic batch configuration.</span>
|
||
|
||
<span class="sd"> Controls how batch size and token limits are dynamically adjusted at runtime.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">enable_batch_size_tuning</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Controls if the batch size should be tuned dynamically"</span><span class="p">)</span>
|
||
|
||
<span class="n">enable_max_num_tokens_tuning</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Controls if the max num tokens should be tuned dynamically"</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">dynamic_batch_moving_average_window</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The window size for moving average of input and output length which is used to calculate dynamic batch size and max num tokens"</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_to_pybind</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">_DynamicBatchConfig</span><span class="p">(</span>
|
||
<span class="n">enable_batch_size_tuning</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">enable_batch_size_tuning</span><span class="p">,</span>
|
||
<span class="n">enable_max_num_tokens_tuning</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">enable_max_num_tokens_tuning</span><span class="p">,</span>
|
||
<span class="n">dynamic_batch_moving_average_window</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span>
|
||
<span class="n">dynamic_batch_moving_average_window</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="SchedulerConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.SchedulerConfig">[docs]</a>
|
||
<span class="nd">@PybindMirror</span><span class="o">.</span><span class="n">mirror_pybind_fields</span><span class="p">(</span><span class="n">_SchedulerConfig</span><span class="p">)</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">SchedulerConfig</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">,</span> <span class="n">PybindMirror</span><span class="p">):</span>
|
||
<span class="n">capacity_scheduler_policy</span><span class="p">:</span> <span class="n">CapacitySchedulerPolicy</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="n">CapacitySchedulerPolicy</span><span class="o">.</span><span class="n">GUARANTEED_NO_EVICT</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The capacity scheduler policy to use"</span><span class="p">)</span>
|
||
|
||
<span class="n">context_chunking_policy</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">ContextChunkingPolicy</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The context chunking policy to use"</span><span class="p">)</span>
|
||
|
||
<span class="n">dynamic_batch_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">DynamicBatchConfig</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The dynamic batch config to use"</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_to_pybind</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">_SchedulerConfig</span><span class="p">(</span>
|
||
<span class="n">capacity_scheduler_policy</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">capacity_scheduler_policy</span><span class="o">.</span><span class="n">_to_pybind</span><span class="p">(</span>
|
||
<span class="p">),</span>
|
||
<span class="n">context_chunking_policy</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">context_chunking_policy</span><span class="o">.</span><span class="n">_to_pybind</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">context_chunking_policy</span> <span class="k">else</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">dynamic_batch_config</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dynamic_batch_config</span><span class="o">.</span><span class="n">_to_pybind</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_batch_config</span> <span class="k">else</span> <span class="kc">None</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="nd">@PybindMirror</span><span class="o">.</span><span class="n">mirror_pybind_fields</span><span class="p">(</span><span class="n">_PeftCacheConfig</span><span class="p">)</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">PeftCacheConfig</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">,</span> <span class="n">PybindMirror</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Configuration for the PEFT cache.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">num_host_module_layer</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"number of max sized 1-layer 1-module adapterSize=1 sets of weights that can be stored in host cache"</span>
|
||
<span class="s2">", affects host cache size and overrides value of host_cache_size"</span><span class="p">)</span>
|
||
<span class="n">num_device_module_layer</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"number of max sized 1-layer 1-module sets of weights that can be stored in device cache"</span>
|
||
<span class="s2">", affects device cache size and overrides value of device_cache_percent"</span>
|
||
<span class="p">)</span>
|
||
<span class="n">optimal_adapter_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span>
|
||
<span class="mi">8</span><span class="p">,</span> <span class="c1"># There are tests to keep the default value consistent with the pybind default value</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"optimal adapter size used to set page width"</span><span class="p">)</span>
|
||
<span class="n">max_adapter_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"max supported adapter size. Used to compute minimum"</span><span class="p">)</span>
|
||
<span class="n">num_put_workers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"number of worker threads used to put weights into host cache"</span><span class="p">)</span>
|
||
<span class="n">num_ensure_workers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"number of worker threads used to copy weights from host to device"</span><span class="p">)</span>
|
||
<span class="n">num_copy_streams</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"number of streams used to copy weights from host to device"</span>
|
||
<span class="p">)</span>
|
||
<span class="n">max_pages_per_block_host</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">24</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Number of cache pages per allocation block (host)"</span><span class="p">)</span>
|
||
<span class="n">max_pages_per_block_device</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Number of cache pages per allocation block (device)"</span><span class="p">)</span>
|
||
<span class="n">device_cache_percent</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mf">0.02</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Proportion of free device memory after engine load to use for cache, as a fraction from 0 to 1"</span>
|
||
<span class="p">)</span>
|
||
<span class="n">host_cache_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">1024</span><span class="o">**</span><span class="mi">3</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"size in bytes to use for host cache"</span><span class="p">)</span>
|
||
<span class="n">lora_prefetch_dir</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"folder to store the LoRA weights we hope to load during engine initialization, currently not supported"</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_to_pybind</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">_PeftCacheConfig</span><span class="p">(</span>
|
||
<span class="n">num_host_module_layer</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_host_module_layer</span><span class="p">,</span>
|
||
<span class="n">num_device_module_layer</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_device_module_layer</span><span class="p">,</span>
|
||
<span class="n">optimal_adapter_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">optimal_adapter_size</span><span class="p">,</span>
|
||
<span class="n">max_adapter_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">max_adapter_size</span><span class="p">,</span>
|
||
<span class="n">num_put_workers</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_put_workers</span><span class="p">,</span>
|
||
<span class="n">num_ensure_workers</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_ensure_workers</span><span class="p">,</span>
|
||
<span class="n">num_copy_streams</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_copy_streams</span><span class="p">,</span>
|
||
<span class="n">max_pages_per_block_host</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">max_pages_per_block_host</span><span class="p">,</span>
|
||
<span class="n">max_pages_per_block_device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">max_pages_per_block_device</span><span class="p">,</span>
|
||
<span class="n">device_cache_percent</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device_cache_percent</span><span class="p">,</span>
|
||
<span class="n">host_cache_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">host_cache_size</span><span class="p">,</span>
|
||
<span class="n">lora_prefetch_dir</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">lora_prefetch_dir</span><span class="p">)</span>
|
||
|
||
|
||
<div class="viewcode-block" id="LookaheadDecodingConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.LookaheadDecodingConfig">[docs]</a>
|
||
<span class="nd">@PybindMirror</span><span class="o">.</span><span class="n">mirror_pybind_fields</span><span class="p">(</span><span class="n">_LookaheadDecodingConfig</span><span class="p">)</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">LookaheadDecodingConfig</span><span class="p">(</span><span class="n">DecodingBaseConfig</span><span class="p">,</span> <span class="n">PybindMirror</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Configuration for lookahead speculative decoding.</span>
|
||
<span class="sd"> """</span>
|
||
|
||
<span class="n">max_window_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="n">_LookaheadDecodingConfig</span><span class="o">.</span><span class="n">get_default_lookahead_decoding_window</span><span class="p">(</span>
|
||
<span class="p">),</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Number of NGrams in lookahead branch per step."</span><span class="p">)</span>
|
||
<span class="n">max_ngram_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="n">_LookaheadDecodingConfig</span><span class="o">.</span><span class="n">get_default_lookahead_decoding_ngram</span><span class="p">(),</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Number of tokens per NGram."</span><span class="p">)</span>
|
||
<span class="n">max_verification_set_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="n">_LookaheadDecodingConfig</span><span class="o">.</span>
|
||
<span class="n">get_default_lookahead_decoding_verification_set</span><span class="p">(),</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Number of NGrams in verification branch per step."</span><span class="p">)</span>
|
||
|
||
<div class="viewcode-block" id="LookaheadDecodingConfig.validate_positive_values">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.LookaheadDecodingConfig.validate_positive_values">[docs]</a>
|
||
<span class="nd">@field_validator</span><span class="p">(</span><span class="s1">'max_window_size'</span><span class="p">,</span> <span class="s1">'max_ngram_size'</span><span class="p">,</span>
|
||
<span class="s1">'max_verification_set_size'</span><span class="p">)</span>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_positive_values</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">v</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">v</span> <span class="o"><=</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Value must be positive, got </span><span class="si">{</span><span class="n">v</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">v</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="LookaheadDecodingConfig.__init__">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.LookaheadDecodingConfig.__init__">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">data</span><span class="p">):</span>
|
||
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_total_draft_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="c1"># Current Lookahead only support linear tree</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_check_fields</span><span class="p">()</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="LookaheadDecodingConfig.calculate_speculative_resource">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.LookaheadDecodingConfig.calculate_speculative_resource">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">calculate_speculative_resource</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">_LookaheadDecodingConfig</span><span class="o">.</span><span class="n">calculate_speculative_resource_tuple</span><span class="p">(</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_window_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_ngram_size</span><span class="p">,</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_verification_set_size</span><span class="p">)</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="LookaheadDecodingConfig.from_dict">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.LookaheadDecodingConfig.from_dict">[docs]</a>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">data</span><span class="p">)</span></div>
|
||
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_to_pybind</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">_LookaheadDecodingConfig</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">max_window_size</span><span class="p">,</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_ngram_size</span><span class="p">,</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_verification_set_size</span><span class="p">)</span>
|
||
|
||
<div class="viewcode-block" id="LookaheadDecodingConfig.supports_backend">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.LookaheadDecodingConfig.supports_backend">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">supports_backend</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">backend</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">backend</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">(</span><span class="s2">"pytorch"</span><span class="p">,</span> <span class="s2">"_autodeploy"</span><span class="p">)</span></div>
|
||
|
||
|
||
<span class="n">decoding_type</span><span class="p">:</span> <span class="n">ClassVar</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"Lookahead"</span></div>
|
||
|
||
|
||
|
||
<span class="n">SpeculativeConfig</span><span class="p">:</span> <span class="n">TypeAlias</span> <span class="o">=</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span>
|
||
<span class="n">DraftTargetDecodingConfig</span><span class="p">,</span>
|
||
<span class="n">EagleDecodingConfig</span><span class="p">,</span>
|
||
<span class="n">LookaheadDecodingConfig</span><span class="p">,</span>
|
||
<span class="n">MedusaDecodingConfig</span><span class="p">,</span>
|
||
<span class="n">MTPDecodingConfig</span><span class="p">,</span>
|
||
<span class="n">NGramDecodingConfig</span><span class="p">,</span>
|
||
<span class="n">UserProvidedDecodingConfig</span><span class="p">,</span>
|
||
<span class="n">SaveHiddenStatesDecodingConfig</span><span class="p">,</span>
|
||
<span class="n">AutoDecodingConfig</span><span class="p">,</span>
|
||
<span class="p">]]</span>
|
||
|
||
<span class="n">SparseAttentionConfig</span><span class="p">:</span> <span class="n">TypeAlias</span> <span class="o">=</span> <span class="n">Union</span><span class="p">[</span>
|
||
<span class="n">RocketSparseAttentionConfig</span><span class="p">,</span>
|
||
<span class="n">DeepSeekSparseAttentionConfig</span><span class="p">,</span>
|
||
<span class="p">]</span>
|
||
|
||
|
||
<div class="viewcode-block" id="KvCacheConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.KvCacheConfig">[docs]</a>
|
||
<span class="nd">@PybindMirror</span><span class="o">.</span><span class="n">mirror_pybind_fields</span><span class="p">(</span><span class="n">_KvCacheConfig</span><span class="p">)</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">KvCacheConfig</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">,</span> <span class="n">PybindMirror</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Configuration for the KV cache.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">enable_block_reuse</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Controls if KV cache blocks can be reused for different requests."</span><span class="p">)</span>
|
||
<span class="n">max_tokens</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The maximum number of tokens that should be stored in the KV cache. If both `max_tokens` and `free_gpu_memory_fraction` are specified, memory corresponding to the minimum will be used."</span>
|
||
<span class="p">)</span>
|
||
<span class="n">max_attention_window</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Size of the attention window for each sequence. Only the last tokens will be stored in the KV cache. If the number of elements in `max_attention_window` is less than the number of layers, `max_attention_window` will be repeated multiple times to the number of layers."</span>
|
||
<span class="p">)</span>
|
||
<span class="n">sink_token_length</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Number of sink tokens (tokens to always keep in attention window)."</span><span class="p">)</span>
|
||
<span class="n">free_gpu_memory_fraction</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The fraction of GPU memory fraction that should be allocated for the KV cache. Default is 90%. If both `max_tokens` and `free_gpu_memory_fraction` are specified, memory corresponding to the minimum will be used."</span>
|
||
<span class="p">)</span>
|
||
<span class="n">host_cache_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Size of the host cache in bytes. If both `max_tokens` and `host_cache_size` are specified, memory corresponding to the minimum will be used."</span>
|
||
<span class="p">)</span>
|
||
<span class="n">onboard_blocks</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"Controls if blocks are onboarded."</span><span class="p">)</span>
|
||
<span class="n">cross_kv_cache_fraction</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The fraction of the KV Cache memory should be reserved for cross attention. If set to p, self attention will use 1-p of KV Cache memory and cross attention will use p of KV Cache memory. Default is 50%. Should only be set when using encoder-decoder model."</span>
|
||
<span class="p">)</span>
|
||
<span class="n">secondary_offload_min_priority</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Only blocks with priority > mSecondaryOfflineMinPriority can be offloaded to secondary memory."</span>
|
||
<span class="p">)</span>
|
||
<span class="n">event_buffer_max_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Maximum size of the event buffer. If set to 0, the event buffer will not be used."</span>
|
||
<span class="p">)</span>
|
||
<span class="n">attention_dp_events_gather_period_ms</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The period in milliseconds to gather attention DP events across ranks."</span>
|
||
<span class="p">)</span>
|
||
<span class="n">enable_partial_reuse</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Whether blocks that are only partially matched can be reused."</span><span class="p">)</span>
|
||
<span class="n">copy_on_partial_reuse</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Whether partially matched blocks that are in use can be reused after copying them."</span>
|
||
<span class="p">)</span>
|
||
<span class="n">use_uvm</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Whether to use UVM for the KV cache."</span><span class="p">)</span>
|
||
<span class="n">max_gpu_total_bytes</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The maximum size in bytes of GPU memory that can be allocated for the KV cache. If both `max_gpu_total_bytes` and `free_gpu_memory_fraction` are specified, memory corresponding to the minimum will be allocated."</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="c1"># This is a pure python field, not a pybind field. It is only for the Pytorch backend.</span>
|
||
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="s2">"auto"</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The data type to use for the KV cache."</span><span class="p">)</span>
|
||
|
||
<span class="c1"># This is a pure python field, not a pybind field. It is only for the Pytorch backend.</span>
|
||
<span class="n">mamba_ssm_cache_dtype</span><span class="p">:</span> <span class="n">Literal</span><span class="p">[</span>
|
||
<span class="s2">"auto"</span><span class="p">,</span> <span class="s2">"float16"</span><span class="p">,</span> <span class="s2">"bfloat16"</span><span class="p">,</span> <span class="s2">"float32"</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="s2">"auto"</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The data type to use for the Mamba SSM cache. If set to 'auto', the data type will be inferred from the model config."</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">tokens_per_block</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The number of tokens per block."</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_to_pybind</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">_KvCacheConfig</span><span class="p">(</span>
|
||
<span class="n">enable_block_reuse</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">enable_block_reuse</span><span class="p">,</span>
|
||
<span class="n">max_tokens</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">max_tokens</span><span class="p">,</span>
|
||
<span class="n">max_attention_window</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">max_attention_window</span><span class="p">,</span>
|
||
<span class="n">sink_token_length</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sink_token_length</span><span class="p">,</span>
|
||
<span class="n">free_gpu_memory_fraction</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">free_gpu_memory_fraction</span><span class="p">,</span>
|
||
<span class="n">host_cache_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">host_cache_size</span><span class="p">,</span>
|
||
<span class="n">onboard_blocks</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">onboard_blocks</span><span class="p">,</span>
|
||
<span class="n">cross_kv_cache_fraction</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">cross_kv_cache_fraction</span><span class="p">,</span>
|
||
<span class="n">secondary_offload_min_priority</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">secondary_offload_min_priority</span><span class="p">,</span>
|
||
<span class="n">event_buffer_max_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">event_buffer_max_size</span><span class="p">,</span>
|
||
<span class="n">enable_partial_reuse</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">enable_partial_reuse</span><span class="p">,</span>
|
||
<span class="n">copy_on_partial_reuse</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">copy_on_partial_reuse</span><span class="p">,</span>
|
||
<span class="n">use_uvm</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">use_uvm</span><span class="p">,</span>
|
||
<span class="n">attention_dp_events_gather_period_ms</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span>
|
||
<span class="n">attention_dp_events_gather_period_ms</span><span class="p">,</span>
|
||
<span class="n">max_gpu_total_bytes</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">max_gpu_total_bytes</span><span class="p">)</span>
|
||
|
||
<div class="viewcode-block" id="KvCacheConfig.validate_free_gpu_memory_fraction">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.KvCacheConfig.validate_free_gpu_memory_fraction">[docs]</a>
|
||
<span class="nd">@field_validator</span><span class="p">(</span><span class="s1">'free_gpu_memory_fraction'</span><span class="p">)</span>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_free_gpu_memory_fraction</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""Validates that the fraction is between 0.0 and 1.0."""</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="mi">0</span> <span class="o"><=</span> <span class="n">v</span> <span class="o"><=</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"kv_cache_config.free_gpu_memory_fraction must be a float between 0 and 1"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">v</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="KvCacheConfig.validate_max_gpu_total_bytes">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.KvCacheConfig.validate_max_gpu_total_bytes">[docs]</a>
|
||
<span class="nd">@field_validator</span><span class="p">(</span><span class="s1">'max_gpu_total_bytes'</span><span class="p">)</span>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_max_gpu_total_bytes</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">v</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"kv_cache_config.max_gpu_total_bytes must be non-negative"</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">v</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="KvCacheConfig.validate_max_attention_window">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.KvCacheConfig.validate_max_attention_window">[docs]</a>
|
||
<span class="nd">@field_validator</span><span class="p">(</span><span class="s1">'max_attention_window'</span><span class="p">)</span>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_max_attention_window</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">v</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]):</span>
|
||
<span class="c1"># Allow unset</span>
|
||
<span class="k">if</span> <span class="n">v</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">v</span>
|
||
|
||
<span class="c1"># Must be a non-empty list of positive integers</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">v</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"kv_cache_config.max_attention_window must be a non-empty list of positive integers"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">v</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"kv_cache_config.max_attention_window must contain only integers"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">i</span> <span class="o"><=</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"kv_cache_config.max_attention_window values must be positive"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">v</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="ExtendedRuntimePerfKnobConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.ExtendedRuntimePerfKnobConfig">[docs]</a>
|
||
<span class="nd">@PybindMirror</span><span class="o">.</span><span class="n">mirror_pybind_fields</span><span class="p">(</span><span class="n">_ExtendedRuntimePerfKnobConfig</span><span class="p">)</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">ExtendedRuntimePerfKnobConfig</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">,</span> <span class="n">PybindMirror</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Configuration for extended runtime performance knobs.</span>
|
||
<span class="sd"> """</span>
|
||
|
||
<span class="n">multi_block_mode</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"Whether to use multi-block mode."</span><span class="p">)</span>
|
||
|
||
<span class="n">enable_context_fmha_fp32_acc</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Whether to enable context FMHA FP32 accumulation."</span><span class="p">)</span>
|
||
|
||
<span class="n">cuda_graph_mode</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Whether to use CUDA graph mode."</span><span class="p">)</span>
|
||
|
||
<span class="n">cuda_graph_cache_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Number of cuda graphs to be cached in the runtime. The larger the cache, the better the perf, but more GPU memory is consumed."</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_to_pybind</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="n">res</span> <span class="o">=</span> <span class="n">_ExtendedRuntimePerfKnobConfig</span><span class="p">(</span>
|
||
<span class="n">multi_block_mode</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">multi_block_mode</span><span class="p">,</span>
|
||
<span class="n">enable_context_fmha_fp32_acc</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">enable_context_fmha_fp32_acc</span><span class="p">)</span>
|
||
<span class="n">res</span><span class="o">.</span><span class="n">cuda_graph_mode</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_mode</span>
|
||
<span class="n">res</span><span class="o">.</span><span class="n">cuda_graph_cache_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_cache_size</span>
|
||
<span class="k">return</span> <span class="n">res</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="CacheTransceiverConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.CacheTransceiverConfig">[docs]</a>
|
||
<span class="nd">@PybindMirror</span><span class="o">.</span><span class="n">mirror_pybind_fields</span><span class="p">(</span><span class="n">_CacheTransceiverConfig</span><span class="p">)</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">CacheTransceiverConfig</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">,</span> <span class="n">PybindMirror</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Configuration for the cache transceiver.</span>
|
||
<span class="sd"> """</span>
|
||
|
||
<span class="n">backend</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Literal</span><span class="p">[</span><span class="s2">"DEFAULT"</span><span class="p">,</span> <span class="s2">"UCX"</span><span class="p">,</span> <span class="s2">"NIXL"</span><span class="p">,</span> <span class="s2">"MPI"</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The communication backend type to use for the cache transceiver."</span><span class="p">)</span>
|
||
|
||
<span class="n">max_tokens_in_buffer</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The max number of tokens the transfer buffer can fit."</span><span class="p">)</span>
|
||
|
||
<span class="n">kv_transfer_timeout_ms</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">gt</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Timeout in milliseconds for KV cache transfer. Requests exceeding this timeout will be cancelled."</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">kv_transfer_sender_future_timeout_ms</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span>
|
||
<span class="n">gt</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Timeout in milliseconds to wait for the sender future to be ready when scheduled batch size is 0. This allows the request to be eventually cancelled by the user or because of kv_transfer_timeout_ms"</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_to_pybind</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">_CacheTransceiverConfig</span><span class="p">(</span>
|
||
<span class="n">backend</span><span class="o">=</span><span class="n">_CacheTransceiverBackendType</span><span class="o">.</span><span class="n">from_string</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">backend</span><span class="p">),</span>
|
||
<span class="n">max_tokens_in_buffer</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">max_tokens_in_buffer</span><span class="p">,</span>
|
||
<span class="n">kv_transfer_timeout_ms</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">kv_transfer_timeout_ms</span><span class="p">,</span>
|
||
<span class="n">kv_transfer_sender_future_timeout_ms</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span>
|
||
<span class="n">kv_transfer_sender_future_timeout_ms</span><span class="p">)</span></div>
|
||
|
||
|
||
|
||
<span class="nd">@dataclass</span>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">_ModelWrapper</span><span class="p">:</span>
|
||
<span class="n">model</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">]</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">__post_init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"model should be provided."</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">,</span>
|
||
<span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">)),</span> <span class="sa">f</span><span class="s2">"Invalid model: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="si">}</span><span class="s2">"</span>
|
||
|
||
<span class="n">model_dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">model_dir</span><span class="o">.</span><span class="n">exists</span><span class="p">()</span> <span class="ow">and</span> <span class="n">model_dir</span><span class="o">.</span><span class="n">is_dir</span><span class="p">():</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model_dir</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">is_hub_model</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_local_model</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">is_local_model</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">,</span> <span class="n">Path</span><span class="p">)</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">model_dir</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Path</span><span class="p">:</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_local_model</span><span class="p">,</span> <span class="sa">f</span><span class="s2">"model_dir is only available for local model, </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="si">}</span><span class="s2">."</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span>
|
||
|
||
<span class="nd">@model_dir</span><span class="o">.</span><span class="n">setter</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">model_dir</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_dir</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">]):</span>
|
||
<span class="n">model_dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="n">model_dir</span><span class="p">)</span>
|
||
<span class="k">assert</span> <span class="n">model_dir</span><span class="o">.</span><span class="n">exists</span><span class="p">()</span> <span class="ow">and</span> <span class="n">model_dir</span><span class="o">.</span><span class="n">is_dir</span><span class="p">(</span>
|
||
<span class="p">),</span> <span class="sa">f</span><span class="s2">"model_dir is not a valid path, </span><span class="si">{</span><span class="n">model_dir</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model_dir</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">model_name</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">]:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> <span class="k">else</span> <span class="kc">None</span>
|
||
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">BaseLlmArgs</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Base class for both TorchLlmArgs and TrtLlmArgs. It contains all the arguments that are common to both.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">model_config</span> <span class="o">=</span> <span class="p">{</span>
|
||
<span class="s2">"arbitrary_types_allowed"</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
|
||
<span class="s2">"extra"</span><span class="p">:</span> <span class="s2">"forbid"</span><span class="p">,</span>
|
||
<span class="p">}</span>
|
||
|
||
<span class="c1"># Explicit arguments</span>
|
||
<span class="n">model</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The path to the model checkpoint or the model name from the Hugging Face Hub."</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">tokenizer</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span>
|
||
<span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">,</span> <span class="n">TokenizerBase</span><span class="p">,</span> <span class="n">PreTrainedTokenizerBase</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The path to the tokenizer checkpoint or the tokenizer name from the Hugging Face Hub."</span><span class="p">,</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
||
|
||
<span class="n">tokenizer_mode</span><span class="p">:</span> <span class="n">Literal</span><span class="p">[</span><span class="s1">'auto'</span><span class="p">,</span> <span class="s1">'slow'</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="s1">'auto'</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The mode to initialize the tokenizer."</span><span class="p">,</span>
|
||
<span class="n">json_schema_extra</span><span class="o">=</span><span class="p">{</span><span class="s2">"type"</span><span class="p">:</span> <span class="s2">"Literal['auto', 'slow']"</span><span class="p">})</span>
|
||
|
||
<span class="n">skip_tokenizer_init</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Whether to skip the tokenizer initialization."</span><span class="p">)</span>
|
||
|
||
<span class="n">trust_remote_code</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"Whether to trust the remote code."</span><span class="p">)</span>
|
||
|
||
<span class="n">tensor_parallel_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The tensor parallel size."</span><span class="p">)</span>
|
||
|
||
<span class="n">dtype</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="s2">"auto"</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The data type to use for the model."</span><span class="p">)</span>
|
||
|
||
<span class="n">revision</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The revision to use for the model."</span><span class="p">)</span>
|
||
|
||
<span class="n">tokenizer_revision</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The revision to use for the tokenizer."</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Below are all remaining arguments</span>
|
||
|
||
<span class="n">pipeline_parallel_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The pipeline parallel size."</span><span class="p">)</span>
|
||
|
||
<span class="n">context_parallel_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The context parallel size."</span><span class="p">)</span>
|
||
|
||
<span class="n">gpus_per_node</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The number of GPUs per node."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"beta"</span><span class="p">,</span>
|
||
<span class="n">validate_default</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
|
||
<span class="n">moe_cluster_parallel_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The cluster parallel size for MoE models's expert weights."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"beta"</span><span class="p">)</span>
|
||
|
||
<span class="n">moe_tensor_parallel_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The tensor parallel size for MoE models's expert weights."</span><span class="p">)</span>
|
||
|
||
<span class="n">moe_expert_parallel_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The expert parallel size for MoE models's expert weights."</span><span class="p">)</span>
|
||
|
||
<span class="n">enable_attention_dp</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Enable attention data parallel."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"beta"</span><span class="p">)</span>
|
||
|
||
<span class="n">enable_lm_head_tp_in_adp</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Enable LM head TP in attention dp."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">pp_partition</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Pipeline parallel partition, a list of each rank's layer number."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">cp_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">dict</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="nb">dict</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Context parallel config."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">load_format</span><span class="p">:</span> <span class="n">Literal</span><span class="p">[</span><span class="s1">'auto'</span><span class="p">,</span> <span class="s1">'dummy'</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="s1">'auto'</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The format to load the model."</span><span class="p">,</span>
|
||
<span class="n">json_schema_extra</span><span class="o">=</span><span class="p">{</span><span class="s2">"type"</span><span class="p">:</span> <span class="s2">"Literal['auto', 'dummy']"</span><span class="p">})</span>
|
||
|
||
<span class="n">fail_fast_on_attention_window_too_large</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Fail fast when attention window is too large to fit even a single sequence in the KV cache."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="c1"># LoRA arguments</span>
|
||
<span class="n">enable_lora</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"Enable LoRA."</span><span class="p">)</span>
|
||
|
||
<span class="n">lora_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">LoraConfig</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"LoRA configuration for the model."</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Several options from ExecutorConfig, expanded here for less hierarchy</span>
|
||
<span class="n">kv_cache_config</span><span class="p">:</span> <span class="n">KvCacheConfig</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="n">KvCacheConfig</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"KV cache config."</span><span class="p">)</span>
|
||
|
||
<span class="n">enable_chunked_prefill</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Enable chunked prefill."</span><span class="p">)</span>
|
||
|
||
<span class="n">guided_decoding_backend</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Literal</span><span class="p">[</span><span class="s2">"xgrammar"</span><span class="p">,</span> <span class="s2">"llguidance"</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Guided decoding backend. llguidance is supported in PyTorch backend only."</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">batched_logits_processor</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">object</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Batched logits processor."</span><span class="p">,</span>
|
||
<span class="n">json_schema_extra</span><span class="o">=</span><span class="p">{</span>
|
||
<span class="s2">"type"</span><span class="p">:</span> <span class="sa">f</span><span class="s2">"Optional[</span><span class="si">{</span><span class="n">get_type_repr</span><span class="p">(</span><span class="n">BatchedLogitsProcessor</span><span class="p">)</span><span class="si">}</span><span class="s2">]"</span>
|
||
<span class="p">})</span>
|
||
|
||
<span class="n">iter_stats_max_iterations</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The maximum number of iterations for iter stats."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">request_stats_max_iterations</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The maximum number of iterations for request stats."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="c1"># A handful of options from PretrainedConfig</span>
|
||
<span class="n">peft_cache_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">PeftCacheConfig</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"PEFT cache config."</span><span class="p">,</span> <span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">scheduler_config</span><span class="p">:</span> <span class="n">SchedulerConfig</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="n">SchedulerConfig</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Scheduler config."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">cache_transceiver_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">CacheTransceiverConfig</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Cache transceiver config."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Sparse attention config</span>
|
||
<span class="n">sparse_attention_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">SparseAttentionConfig</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Sparse attention config."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Speculative decoding parameters</span>
|
||
<span class="n">speculative_config</span><span class="p">:</span> <span class="n">SpeculativeConfig</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"Speculative decoding config."</span><span class="p">)</span>
|
||
|
||
<span class="n">max_batch_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The maximum batch size."</span><span class="p">)</span>
|
||
|
||
<span class="c1"># generation constraints</span>
|
||
<span class="n">max_input_len</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The maximum input length."</span><span class="p">)</span>
|
||
|
||
<span class="n">max_seq_len</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The maximum sequence length."</span><span class="p">)</span>
|
||
|
||
<span class="n">max_beam_width</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The maximum beam width."</span><span class="p">)</span>
|
||
|
||
<span class="n">max_num_tokens</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">8192</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The maximum number of tokens."</span><span class="p">)</span>
|
||
|
||
<span class="n">gather_generation_logits</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Gather generation logits."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="c1"># private fields those are unstable and just for internal use</span>
|
||
<span class="n">num_postprocess_workers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The number of processes used for postprocessing the generated tokens, including detokenization."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">postprocess_tokenizer_dir</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The path to the tokenizer directory for postprocessing."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">reasoning_parser</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The parser to separate reasoning content from output."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="c1"># TODO[Superjomn]: To deprecate this config.</span>
|
||
<span class="n">decoding_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">object</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The decoding config."</span><span class="p">,</span>
|
||
<span class="n">json_schema_extra</span><span class="o">=</span><span class="p">{</span>
|
||
<span class="s2">"type"</span><span class="p">:</span> <span class="s2">"Optional[tensorrt_llm.llmapi.llm_args.DecodingConfig]"</span>
|
||
<span class="p">},</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"deprecated"</span><span class="p">,</span>
|
||
<span class="n">deprecated</span><span class="o">=</span><span class="s2">"Use speculative_config instead."</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">mpi_session</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">object</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The optional MPI session to use for this LLM instance."</span><span class="p">,</span>
|
||
<span class="n">json_schema_extra</span><span class="o">=</span><span class="p">{</span><span class="s2">"type"</span><span class="p">:</span> <span class="s2">"Optional[MpiSession]"</span><span class="p">},</span>
|
||
<span class="n">exclude</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">alias</span><span class="o">=</span><span class="s2">"_mpi_session"</span><span class="p">)</span>
|
||
|
||
<span class="n">otlp_traces_endpoint</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Target URL to which OpenTelemetry traces will be sent."</span><span class="p">,</span>
|
||
<span class="n">alias</span><span class="o">=</span><span class="s2">"otlp_traces_endpoint"</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">backend</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The backend to use for this LLM instance."</span><span class="p">,</span>
|
||
<span class="n">exclude_json_schema</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="c1"># hide from API references</span>
|
||
<span class="n">validate_default</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"deprecated"</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">return_perf_metrics</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Return perf metrics."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">orchestrator_type</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Literal</span><span class="p">[</span><span class="s2">"rpc"</span><span class="p">,</span> <span class="s2">"ray"</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The orchestrator type to use. Defaults to None, which uses MPI."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">_parallel_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">_ParallelConfig</span><span class="p">]</span> <span class="o">=</span> <span class="n">PrivateAttr</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
||
<span class="n">_model_format</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">_ModelFormatKind</span><span class="p">]</span> <span class="o">=</span> <span class="n">PrivateAttr</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
||
<span class="n">_speculative_model</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">PrivateAttr</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
||
<span class="n">_speculative_model_format</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">_ModelFormatKind</span><span class="p">]</span> <span class="o">=</span> <span class="n">PrivateAttr</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">parallel_config</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">_ParallelConfig</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_parallel_config</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">model_format</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">_ModelFormatKind</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_model_format</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">speculative_model_dir</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Optional</span><span class="p">[</span><span class="n">_ModelFormatKind</span><span class="p">]:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_speculative_model</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">speculative_model_format</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">_ModelFormatKind</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_speculative_model_format</span>
|
||
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">from_kwargs</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-></span> <span class="s2">"BaseLlmArgs"</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""Create `LlmArgs` instance from kwargs.</span>
|
||
|
||
<span class="sd"> Args:</span>
|
||
<span class="sd"> kwargs (Any): Arguments passed to `LlmArgs` constructor.</span>
|
||
|
||
<span class="sd"> Returns:</span>
|
||
<span class="sd"> tensorrt_llm.llmapi.llm_utils.BaseLlmArgs: The `BaseLlmArgs` instance.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">kwargs</span> <span class="o">=</span> <span class="n">BaseLlmArgs</span><span class="o">.</span><span class="n">_check_consistency</span><span class="p">(</span><span class="nb">dict</span><span class="p">(</span><span class="n">kwargs</span><span class="p">))</span>
|
||
<span class="n">ret</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">ret</span>
|
||
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_check_consistency</span><span class="p">(</span><span class="n">kwargs_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">])</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]:</span>
|
||
<span class="c1"># max_beam_width is not included since vague behavior due to lacking the support for dynamic beam width during</span>
|
||
<span class="c1"># generation</span>
|
||
<span class="n">black_list</span> <span class="o">=</span> <span class="nb">set</span><span class="p">([</span><span class="s2">"max_beam_width"</span><span class="p">])</span>
|
||
<span class="n">executor_config_attrs</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span>
|
||
<span class="n">attr</span> <span class="k">for</span> <span class="n">attr</span> <span class="ow">in</span> <span class="nb">dir</span><span class="p">(</span><span class="n">_ExecutorConfig</span><span class="p">)</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">attr</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">'_'</span><span class="p">)</span>
|
||
<span class="ow">and</span> <span class="nb">callable</span><span class="p">(</span><span class="nb">getattr</span><span class="p">(</span><span class="n">_ExecutorConfig</span><span class="p">,</span> <span class="n">attr</span><span class="p">)))</span>
|
||
<span class="n">executor_config_attrs</span> <span class="o">-=</span> <span class="n">black_list</span>
|
||
<span class="n">llm_args_attr</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">BaseLlmArgs</span><span class="o">.</span><span class="n">model_fields</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
|
||
<span class="c1"># NOTE: When cpp ExecutorConfig add new options, please add the new options into `LlmArgs` with docs as well</span>
|
||
<span class="c1"># ASK chunweiy for help if you are not sure about the new options.</span>
|
||
<span class="k">assert</span> <span class="n">executor_config_attrs</span><span class="o">.</span><span class="n">issubset</span><span class="p">(</span>
|
||
<span class="n">llm_args_attr</span>
|
||
<span class="p">),</span> <span class="sa">f</span><span class="s2">"New options found in underlying ExecutorConfig: </span><span class="si">{</span><span class="n">llm_args_attr</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="n">executor_config_attrs</span><span class="si">}</span><span class="s2">"</span>
|
||
|
||
<span class="k">return</span> <span class="n">kwargs_dict</span>
|
||
|
||
<span class="nd">@field_validator</span><span class="p">(</span><span class="s2">"dtype"</span><span class="p">)</span>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_dtype</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">info</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">get_device_properties</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">major</span> <span class="o"><</span> <span class="mi">8</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">v</span> <span class="o">==</span> <span class="s1">'auto'</span><span class="p">:</span>
|
||
<span class="n">v</span> <span class="o">=</span> <span class="s1">'float16'</span>
|
||
<span class="k">if</span> <span class="n">v</span> <span class="o">==</span> <span class="s1">'bfloat16'</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s2">"Pre SM 80 GPUs do not support bfloat16"</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">v</span>
|
||
|
||
<span class="nd">@field_validator</span><span class="p">(</span><span class="s2">"gpus_per_node"</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">'before'</span><span class="p">)</span>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_gpus_per_node</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">info</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">v</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"Using default gpus_per_node: </span><span class="si">{</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">device_count</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="n">v</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">device_count</span><span class="p">()</span>
|
||
<span class="k">return</span> <span class="n">v</span>
|
||
|
||
<span class="nd">@field_validator</span><span class="p">(</span><span class="s2">"model"</span><span class="p">)</span>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_model</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">info</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">)):</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid model: </span><span class="si">{</span><span class="n">v</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">v</span>
|
||
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_parallel_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">moe_cluster_parallel_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">moe_cluster_parallel_size</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">moe_tensor_parallel_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">moe_tensor_parallel_size</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">moe_expert_parallel_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">moe_expert_parallel_size</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_parallel_config</span> <span class="o">=</span> <span class="n">_ParallelConfig</span><span class="p">(</span>
|
||
<span class="n">tp_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">tensor_parallel_size</span><span class="p">,</span>
|
||
<span class="n">pp_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">pipeline_parallel_size</span><span class="p">,</span>
|
||
<span class="n">cp_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">context_parallel_size</span><span class="p">,</span>
|
||
<span class="n">gpus_per_node</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">gpus_per_node</span><span class="p">,</span>
|
||
<span class="n">moe_cluster_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_cluster_parallel_size</span><span class="p">,</span>
|
||
<span class="n">moe_tp_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_tensor_parallel_size</span><span class="p">,</span>
|
||
<span class="n">moe_ep_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_expert_parallel_size</span><span class="p">,</span>
|
||
<span class="n">enable_attention_dp</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">enable_attention_dp</span><span class="p">,</span>
|
||
<span class="n">enable_lm_head_tp_in_adp</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">enable_lm_head_tp_in_adp</span><span class="p">,</span>
|
||
<span class="n">pp_partition</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">pp_partition</span><span class="p">,</span>
|
||
<span class="n">cp_config</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">cp_config</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">set_default_max_input_len</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_input_len</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_input_len</span> <span class="o">=</span> <span class="mi">1024</span>
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_and_init_tokenizer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""Initialize tokenizer based on configuration."""</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip_tokenizer_init</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">tokenizer</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">tokenizer</span> <span class="o">=</span> <span class="n">tokenizer_factory</span><span class="p">(</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">tokenizer</span><span class="p">,</span>
|
||
<span class="n">trust_remote_code</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">trust_remote_code</span><span class="p">,</span>
|
||
<span class="n">use_fast</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">tokenizer_mode</span> <span class="o">!=</span> <span class="s1">'slow'</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_model_format_misc</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">'''</span>
|
||
<span class="sd"> Load the model format, and do the following:</span>
|
||
|
||
<span class="sd"> 1. Load the build_config if got an engine.</span>
|
||
<span class="sd"> 2. Load the parallel_config if got a checkpoint.</span>
|
||
<span class="sd"> '''</span>
|
||
<span class="n">model_obj</span> <span class="o">=</span> <span class="n">_ModelWrapper</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">model_obj</span><span class="o">.</span><span class="n">is_local_model</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">backend</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span>
|
||
<span class="s1">'pytorch'</span><span class="p">,</span> <span class="s1">'_autodeploy'</span>
|
||
<span class="p">]:</span>
|
||
<span class="c1"># Load parallel_config from the engine.</span>
|
||
<span class="n">model_format</span> <span class="o">=</span> <span class="n">get_model_format</span><span class="p">(</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">,</span> <span class="n">trust_remote_code</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">trust_remote_code</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">model_format</span> <span class="ow">is</span> <span class="n">_ModelFormatKind</span><span class="o">.</span><span class="n">TLLM_ENGINE</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_config</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="s2">"The build_config is ignored for model format of TLLM_ENGINE."</span>
|
||
<span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_load_config_from_engine</span><span class="p">(</span><span class="n">model_obj</span><span class="o">.</span><span class="n">model_dir</span><span class="p">)</span>
|
||
<span class="n">runtime_defaults</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pretrained_config</span><span class="o">.</span><span class="n">runtime_defaults</span>
|
||
<span class="k">if</span> <span class="n">runtime_defaults</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_config</span><span class="o">.</span><span class="n">fill_empty_fields_from_runtime_defaults</span><span class="p">(</span>
|
||
<span class="n">runtime_defaults</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Load parallel_config from the checkpoint.</span>
|
||
<span class="k">elif</span> <span class="n">model_format</span> <span class="ow">is</span> <span class="n">_ModelFormatKind</span><span class="o">.</span><span class="n">TLLM_CKPT</span><span class="p">:</span>
|
||
<span class="c1"># We need to create a temporary instance to call _load_config_from_ckpt</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_load_config_from_ckpt</span><span class="p">(</span><span class="n">model_obj</span><span class="o">.</span><span class="n">model_dir</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">model_format</span> <span class="o">=</span> <span class="n">_ModelFormatKind</span><span class="o">.</span><span class="n">HF</span>
|
||
|
||
<span class="c1"># Store the model format in the values</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_model_format</span> <span class="o">=</span> <span class="n">model_format</span>
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">init_build_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Creating a default BuildConfig if none is provided</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">build_config</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">"build_config"</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">build_config</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">kwargs</span> <span class="o">=</span> <span class="p">{}</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">:</span>
|
||
<span class="n">kwargs</span><span class="p">[</span><span class="s2">"max_batch_size"</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_batch_size</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_num_tokens</span><span class="p">:</span>
|
||
<span class="n">kwargs</span><span class="p">[</span><span class="s2">"max_num_tokens"</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_num_tokens</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_len</span><span class="p">:</span>
|
||
<span class="n">kwargs</span><span class="p">[</span><span class="s2">"max_seq_len"</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_len</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_beam_width</span><span class="p">:</span>
|
||
<span class="n">kwargs</span><span class="p">[</span><span class="s2">"max_beam_width"</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_beam_width</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_input_len</span><span class="p">:</span>
|
||
<span class="n">kwargs</span><span class="p">[</span><span class="s2">"max_input_len"</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_input_len</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span> <span class="o">=</span> <span class="n">BuildConfig</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">set_runtime_knobs_from_build_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="c1"># TODO: remove this after PyT become default to adapt PyT with build_config as input</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_config</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"build_config is not initialized"</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">backend</span> <span class="o">==</span> <span class="s2">"pytorch"</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="p">:</span>
|
||
<span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="p">[</span>
|
||
<span class="s2">"max_batch_size"</span><span class="p">,</span> <span class="s2">"max_num_tokens"</span><span class="p">,</span> <span class="s2">"max_seq_len"</span><span class="p">,</span>
|
||
<span class="s2">"max_input_len"</span><span class="p">,</span> <span class="s2">"max_beam_width"</span>
|
||
<span class="p">]:</span>
|
||
<span class="k">if</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="p">,</span> <span class="n">key</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="p">(</span><span class="n">v</span> <span class="o">:=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span>
|
||
<span class="kc">None</span><span class="p">))</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">v</span> <span class="o">!=</span> <span class="nb">getattr</span><span class="p">(</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="p">,</span> <span class="n">key</span><span class="p">):</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"overriding </span><span class="si">{</span><span class="n">key</span><span class="si">}</span><span class="s2"> from build_config"</span><span class="p">)</span>
|
||
<span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="p">,</span> <span class="n">key</span><span class="p">))</span>
|
||
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_runtime_args</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_batch_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_num_tokens</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_batch_size</span> <span class="o">></span> <span class="bp">self</span><span class="o">.</span><span class="n">max_num_tokens</span><span class="p">:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"max_batch_size [</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">max_batch_size</span><span class="si">}</span><span class="s2">] should be less than or equal to max_num_tokens [</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">max_num_tokens</span><span class="si">}</span><span class="s2">]"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_build_config_with_runtime_params</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="c1"># Note: max_batch_size and max_num_tokens in LlmArgs are for runtime,</span>
|
||
<span class="c1"># which will be passed to the C++ Executor API, overwriting the values</span>
|
||
<span class="c1"># from an built engine. In order to set build configuration, it is</span>
|
||
<span class="c1"># recommended to use build_config instead.</span>
|
||
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="p">,</span> <span class="n">BuildConfig</span>
|
||
<span class="p">),</span> <span class="sa">f</span><span class="s2">"build_config is not initialized: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="si">}</span><span class="s2">"</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_batch_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_batch_size</span> <span class="o">></span> <span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_batch_size</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"max_batch_size [</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">max_batch_size</span><span class="si">}</span><span class="s2">] is overridden by build_config.max_batch_size [</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="si">}</span><span class="s2">] in build_config"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_num_tokens</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_num_tokens</span> <span class="o">></span> <span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_num_tokens</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_num_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_num_tokens</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"max_num_tokens [</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">max_num_tokens</span><span class="si">}</span><span class="s2">] is overridden by build_config.max_num_tokens [</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_num_tokens</span><span class="si">}</span><span class="s2">] in build_config"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_len</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_seq_len</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="p">:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"max_seq_len [</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">max_seq_len</span><span class="si">}</span><span class="s2">] is overridden by build_config.max_seq_len [</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_seq_len</span><span class="si">}</span><span class="s2">] in build_config"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_beam_width</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_beam_width</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_beam_width</span><span class="p">:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"max_beam_width [</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">max_beam_width</span><span class="si">}</span><span class="s2">] is overridden by build_config.max_beam_width [</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_beam_width</span><span class="si">}</span><span class="s2">] in build_config"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_input_len</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_input_len</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_input_len</span><span class="p">:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"max_input_len [</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">max_input_len</span><span class="si">}</span><span class="s2">] is overridden by build_config.max_input_len [</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_input_len</span><span class="si">}</span><span class="s2">] in build_config"</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_build_config_remaining</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="n">is_trt_llm_args</span> <span class="o">=</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">TrtLlmArgs</span><span class="p">)</span>
|
||
|
||
<span class="c1"># TODO: remove the checker when manage weights support all data types</span>
|
||
<span class="k">if</span> <span class="n">is_trt_llm_args</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">fast_build</span> <span class="ow">and</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">quant_config</span><span class="o">.</span><span class="n">quant_algo</span>
|
||
<span class="ow">is</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">manage_weights</span> <span class="o">=</span> <span class="kc">True</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">world_size</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">nccl_plugin</span> <span class="o">=</span> <span class="kc">None</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable_lora</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">backend</span> <span class="o">!=</span> <span class="s1">'pytorch'</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">plugin_config</span><span class="o">.</span><span class="n">lora_plugin</span> <span class="o">=</span> <span class="s1">'auto'</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_config</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">lora_config</span><span class="o">.</span><span class="n">max_lora_rank</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_config</span><span class="o">.</span><span class="n">max_lora_rank</span>
|
||
|
||
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
|
||
<span class="s1">'enable_prompt_adapter'</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable_prompt_adapter</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_prompt_embedding_table_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_prompt_adapter_token</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_batch_size</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_beam_width</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_beam_width</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_beam_width</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_beam_width</span> <span class="o">=</span> <span class="mi">1</span>
|
||
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_speculative_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">supports_backend</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">backend</span><span class="p">):</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"Speculation type </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">decoding_type</span><span class="si">}</span><span class="s2"> does not "</span>
|
||
<span class="sa">f</span><span class="s2">"support backend </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">backend</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Below, we only need to set speculative_decoding_mode/decoding_config for speculation</span>
|
||
<span class="c1"># on the TRT backend.</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="p">,</span> <span class="n">LookaheadDecodingConfig</span><span class="p">):</span>
|
||
<span class="n">max_draft_len</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">calculate_speculative_resource</span><span class="p">(</span>
|
||
<span class="p">)[</span><span class="mi">2</span><span class="p">]</span>
|
||
<span class="k">assert</span> <span class="n">max_draft_len</span> <span class="o">></span> <span class="mi">0</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">=</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">LOOKAHEAD_DECODING</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span><span class="p">,</span> <span class="n">max_draft_len</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">decoding_config</span> <span class="o">=</span> <span class="n">DecodingConfig</span><span class="p">(</span>
|
||
<span class="n">decoding_mode</span><span class="o">=</span><span class="n">DecodingMode</span><span class="o">.</span><span class="n">Lookahead</span><span class="p">(),</span>
|
||
<span class="n">lookahead_decoding_config</span><span class="o">=</span><span class="n">PybindMirror</span><span class="o">.</span><span class="n">maybe_to_pybind</span><span class="p">(</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="p">))</span>
|
||
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="p">,</span> <span class="n">MedusaDecodingConfig</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">></span> <span class="mi">0</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">=</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">MEDUSA</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">max_draft_len</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">decoding_config</span> <span class="o">=</span> <span class="n">DecodingConfig</span><span class="p">(</span>
|
||
<span class="n">decoding_mode</span><span class="o">=</span><span class="n">DecodingMode</span><span class="o">.</span><span class="n">Medusa</span><span class="p">(),</span>
|
||
<span class="n">medusa_choices</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">medusa_choices</span><span class="p">)</span>
|
||
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="p">,</span> <span class="n">EagleDecodingConfig</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">></span> <span class="mi">0</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">speculative_model_dir</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"Path to EAGLE3 weights must be specified."</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">max_draft_len</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">=</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">EAGLE</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">backend</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'pytorch'</span><span class="p">,</span> <span class="s1">'_autodeploy'</span><span class="p">]:</span>
|
||
<span class="n">eagle_config</span> <span class="o">=</span> <span class="n">_EagleConfig</span><span class="p">(</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">eagle_choices</span><span class="p">,</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">greedy_sampling</span><span class="p">,</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">posterior_threshold</span><span class="p">,</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">use_dynamic_tree</span><span class="p">,</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">dynamic_tree_max_topK</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">decoding_config</span> <span class="o">=</span> <span class="n">DecodingConfig</span><span class="p">(</span>
|
||
<span class="n">decoding_mode</span><span class="o">=</span><span class="n">DecodingMode</span><span class="o">.</span><span class="n">Eagle</span><span class="p">(),</span>
|
||
<span class="n">eagle_config</span><span class="o">=</span><span class="n">eagle_config</span><span class="p">)</span>
|
||
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="p">,</span> <span class="n">NGramDecodingConfig</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">backend</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'pytorch'</span><span class="p">,</span> <span class="s1">'_autodeploy'</span><span class="p">]</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">></span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">max_matching_ngram_size</span> <span class="o">></span> <span class="mi">0</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">=</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">NGRAM</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">max_draft_len</span>
|
||
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="p">,</span> <span class="n">DraftTargetDecodingConfig</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">backend</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'pytorch'</span><span class="p">]</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">></span> <span class="mi">0</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">speculative_model_dir</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">"Path to draft model must be specified."</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">=</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">DRAFT_TOKENS_EXTERNAL</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">max_draft_len</span>
|
||
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="p">,</span> <span class="n">MTPDecodingConfig</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">num_nextn_predict_layers</span> <span class="o">></span> <span class="mi">0</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">num_nextn_predict_layers</span>
|
||
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="p">,</span>
|
||
<span class="n">UserProvidedDecodingConfig</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">backend</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'pytorch'</span><span class="p">,</span> <span class="s1">'_autodeploy'</span><span class="p">]</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">=</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">USER_PROVIDED</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">max_draft_len</span>
|
||
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="p">,</span> <span class="n">AutoDecodingConfig</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">backend</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'pytorch'</span><span class="p">,</span> <span class="s1">'_autodeploy'</span><span class="p">]</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">=</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">AUTO</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">max_draft_len</span>
|
||
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="p">,</span>
|
||
<span class="n">SaveHiddenStatesDecodingConfig</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">backend</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'pytorch'</span><span class="p">]</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="s2">"SaveHiddenStatesDecodingConfig is active, setting max_batch_size to 1, disabling overlap scheduler, and setting cuda_graph_config to None"</span>
|
||
<span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_batch_size</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_batch_size</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">disable_overlap_scheduler</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_config</span> <span class="o">=</span> <span class="kc">None</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">speculative_decoding_mode</span> <span class="o">=</span> <span class="n">SpeculativeDecodingMode</span><span class="o">.</span><span class="n">SAVE_HIDDEN_STATES</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="o">.</span><span class="n">max_draft_len</span> <span class="o">=</span> <span class="mi">1</span>
|
||
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"Unrecognized speculative config type </span><span class="si">{</span><span class="nb">type</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">decoding_config</span> <span class="o">=</span> <span class="kc">None</span>
|
||
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_speculative_model</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">speculative_config</span><span class="p">,</span>
|
||
<span class="s2">"speculative_model_dir"</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
||
<span class="n">speculative_model_obj</span> <span class="o">=</span> <span class="n">_ModelWrapper</span><span class="p">(</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_speculative_model</span>
|
||
<span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_speculative_model</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="kc">None</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_speculative_model</span> <span class="ow">and</span> <span class="n">speculative_model_obj</span><span class="o">.</span><span class="n">is_local_model</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_speculative_model_format</span> <span class="o">=</span> <span class="n">_ModelFormatKind</span><span class="o">.</span><span class="n">HF</span>
|
||
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_lora_config_consistency</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_config</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lora_config</span><span class="o">.</span><span class="n">lora_dir</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="c1"># TODO [TRTLLM-5173]</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="s2">"lora_dir is empty, so custom embedding or lm head will not be applied."</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable_lora</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_config</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">backend</span> <span class="ow">in</span> <span class="p">[</span>
|
||
<span class="s1">'pytorch'</span><span class="p">,</span> <span class="s1">'_autodeploy'</span>
|
||
<span class="p">]:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"enable_lora is ignored when lora_config is provided for </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">backend</span><span class="si">}</span><span class="s2"> backend."</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">lora_config</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lora_config</span><span class="o">.</span><span class="n">lora_dir</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">lora_config</span><span class="o">.</span><span class="n">lora_target_modules</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="s2">"Both lora_dir and lora_target_modules are empty, so all LoRA modules will be expected. "</span>
|
||
<span class="s2">"This will lead to serious memory consumption. Please provide either lora_dir or lora_target_modules if this behavior is not what you expect."</span>
|
||
<span class="p">)</span>
|
||
<span class="n">default_trtllm_modules_to_hf_modules</span> <span class="o">=</span> <span class="n">get_default_trtllm_modules_to_hf_modules</span><span class="p">(</span>
|
||
<span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">lora_config</span><span class="o">.</span><span class="n">lora_target_modules</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span>
|
||
<span class="n">default_trtllm_modules_to_hf_modules</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_peft_cache_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">peft_cache_config</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">peft_cache_config</span><span class="o">.</span><span class="n">lora_prefetch_dir</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"lora_prefetch_dir was set to '</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">peft_cache_config</span><span class="o">.</span><span class="n">lora_prefetch_dir</span><span class="si">}</span><span class="s2">' "</span>
|
||
<span class="s2">"while LoRA prefetch is not supported"</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_load_config_from_engine</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">engine_dir</span><span class="p">:</span> <span class="n">Path</span><span class="p">):</span>
|
||
<span class="n">engine_config</span> <span class="o">=</span> <span class="n">EngineConfig</span><span class="o">.</span><span class="n">from_json_file</span><span class="p">(</span><span class="n">engine_dir</span> <span class="o">/</span> <span class="s2">"config.json"</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_pretrained_config</span> <span class="o">=</span> <span class="n">engine_config</span><span class="o">.</span><span class="n">pretrained_config</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">build_config</span> <span class="o">=</span> <span class="n">engine_config</span><span class="o">.</span><span class="n">build_config</span>
|
||
|
||
<span class="c1"># load and check parallel_config</span>
|
||
<span class="n">mapping</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_pretrained_config</span><span class="o">.</span><span class="n">mapping</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">tp_size</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">):</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"tp_size </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">tp_size</span><span class="si">}</span><span class="s2"> is not consistent with the engine's tp_size </span><span class="si">{</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">pp_size</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span><span class="p">):</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"pp_size </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">pp_size</span><span class="si">}</span><span class="s2"> is not consistent with the engine's pp_size </span><span class="si">{</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">cp_size</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">mapping</span><span class="o">.</span><span class="n">cp_size</span><span class="p">):</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"cp_size </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">cp_size</span><span class="si">}</span><span class="s2"> is not consistent with the engine's cp_size </span><span class="si">{</span><span class="n">mapping</span><span class="o">.</span><span class="n">cp_size</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_parallel_config</span> <span class="o">=</span> <span class="n">_ParallelConfig</span><span class="p">(</span>
|
||
<span class="n">tp_size</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span><span class="p">,</span>
|
||
<span class="n">pp_size</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span><span class="p">,</span>
|
||
<span class="n">cp_size</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">cp_size</span><span class="p">,</span>
|
||
<span class="n">gpus_per_node</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">gpus_per_node</span><span class="p">,</span>
|
||
<span class="n">moe_cluster_size</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">moe_cluster_size</span><span class="p">,</span>
|
||
<span class="n">moe_tp_size</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">moe_tp_size</span><span class="p">,</span>
|
||
<span class="n">moe_ep_size</span><span class="o">=</span><span class="n">mapping</span><span class="o">.</span><span class="n">moe_ep_size</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_load_config_from_ckpt</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ckpt_dir</span><span class="p">:</span> <span class="n">Path</span><span class="p">):</span>
|
||
<span class="n">pretrained_config</span> <span class="o">=</span> <span class="n">PretrainedConfig</span><span class="o">.</span><span class="n">from_json_file</span><span class="p">(</span><span class="n">ckpt_dir</span> <span class="o">/</span>
|
||
<span class="s2">"config.json"</span><span class="p">)</span>
|
||
<span class="n">tp_size</span> <span class="o">=</span> <span class="n">pretrained_config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">tp_size</span>
|
||
<span class="n">pp_size</span> <span class="o">=</span> <span class="n">pretrained_config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">pp_size</span>
|
||
<span class="n">cp_size</span> <span class="o">=</span> <span class="n">pretrained_config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">cp_size</span>
|
||
<span class="n">moe_cluster_size</span> <span class="o">=</span> <span class="n">pretrained_config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">moe_cluster_size</span>
|
||
<span class="n">moe_tp_size</span> <span class="o">=</span> <span class="n">pretrained_config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">moe_tp_size</span>
|
||
<span class="n">moe_ep_size</span> <span class="o">=</span> <span class="n">pretrained_config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">moe_ep_size</span>
|
||
<span class="n">gpus_per_node</span> <span class="o">=</span> <span class="n">pretrained_config</span><span class="o">.</span><span class="n">mapping</span><span class="o">.</span><span class="n">gpus_per_node</span>
|
||
<span class="c1"># load parallel_config</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">!=</span> <span class="mi">1</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">tp_size</span> <span class="o">!=</span> <span class="n">tp_size</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"tp_size </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">tp_size</span><span class="si">}</span><span class="s2"> is not consistent with the checkpoint's tp_size </span><span class="si">{</span><span class="n">tp_size</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">pp_size</span> <span class="o">!=</span> <span class="mi">1</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">pp_size</span> <span class="o">!=</span> <span class="n">pp_size</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"pp_size </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">pp_size</span><span class="si">}</span><span class="s2"> is not consistent with the checkpoint's pp_size </span><span class="si">{</span><span class="n">pp_size</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">cp_size</span> <span class="o">!=</span> <span class="mi">1</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">cp_size</span> <span class="o">!=</span> <span class="n">cp_size</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"cp_size </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">parallel_config</span><span class="o">.</span><span class="n">cp_size</span><span class="si">}</span><span class="s2"> is not consistent with the checkpoint's cp_size </span><span class="si">{</span><span class="n">cp_size</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_parallel_config</span> <span class="o">=</span> <span class="n">_ParallelConfig</span><span class="p">(</span>
|
||
<span class="n">tp_size</span><span class="o">=</span><span class="n">tp_size</span><span class="p">,</span>
|
||
<span class="n">pp_size</span><span class="o">=</span><span class="n">pp_size</span><span class="p">,</span>
|
||
<span class="n">cp_size</span><span class="o">=</span><span class="n">cp_size</span><span class="p">,</span>
|
||
<span class="n">gpus_per_node</span><span class="o">=</span><span class="n">gpus_per_node</span><span class="p">,</span>
|
||
<span class="n">moe_cluster_size</span><span class="o">=</span><span class="n">moe_cluster_size</span><span class="p">,</span>
|
||
<span class="n">moe_tp_size</span><span class="o">=</span><span class="n">moe_tp_size</span><span class="p">,</span>
|
||
<span class="n">moe_ep_size</span><span class="o">=</span><span class="n">moe_ep_size</span><span class="p">)</span>
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">get_runtime_sizes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]:</span>
|
||
<span class="k">return</span> <span class="p">(</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_beam_width</span><span class="p">,</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_num_tokens</span><span class="p">,</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_seq_len</span><span class="p">,</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
|
||
|
||
<div class="viewcode-block" id="TrtLlmArgs">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TrtLlmArgs">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">TrtLlmArgs</span><span class="p">(</span><span class="n">BaseLlmArgs</span><span class="p">):</span>
|
||
<span class="n">enable_tqdm</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Enable tqdm for progress bar."</span><span class="p">)</span>
|
||
|
||
<span class="n">workspace</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The workspace for the model."</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Once set, the model will reuse the build_cache</span>
|
||
<span class="n">enable_build_cache</span><span class="p">:</span> <span class="nb">object</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Enable build cache."</span><span class="p">,</span>
|
||
<span class="n">json_schema_extra</span><span class="o">=</span><span class="p">{</span>
|
||
<span class="s2">"type"</span><span class="p">:</span> <span class="sa">f</span><span class="s2">"Union[</span><span class="si">{</span><span class="n">get_type_repr</span><span class="p">(</span><span class="n">BuildCacheConfig</span><span class="p">)</span><span class="si">}</span><span class="s2">, bool]"</span>
|
||
<span class="p">})</span>
|
||
|
||
<span class="n">extended_runtime_perf_knob_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span>
|
||
<span class="n">ExtendedRuntimePerfKnobConfig</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"Extended runtime perf knob config."</span><span class="p">)</span>
|
||
|
||
<span class="n">calib_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">CalibConfig</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"Calibration config."</span><span class="p">,</span> <span class="n">validate_default</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Quantization and calibration configurations</span>
|
||
<span class="n">quant_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">QuantConfig</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"Quantization config."</span><span class="p">,</span> <span class="n">validate_default</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
|
||
<span class="n">embedding_parallel_mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="s1">'SHARDING_ALONG_VOCAB'</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The embedding parallel mode."</span><span class="p">)</span>
|
||
|
||
<span class="n">fast_build</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"Enable fast build."</span><span class="p">)</span>
|
||
|
||
<span class="c1"># BuildConfig is introduced to give users a familiar interface to configure the model building.</span>
|
||
<span class="n">build_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">BuildConfig</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Build config."</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Prompt adapter arguments</span>
|
||
<span class="n">enable_prompt_adapter</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Enable prompt adapter."</span><span class="p">)</span>
|
||
|
||
<span class="n">max_prompt_adapter_token</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"The maximum number of prompt adapter tokens."</span><span class="p">)</span>
|
||
|
||
<span class="n">batching_type</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">BatchingType</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Batching type."</span><span class="p">)</span>
|
||
|
||
<span class="n">normalize_log_probs</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"Normalize log probabilities."</span><span class="p">)</span>
|
||
|
||
<span class="c1"># Private attributes</span>
|
||
<span class="c1"># This is used to hold the options for convert_checkpoint</span>
|
||
<span class="n">_convert_checkpoint_options</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span>
|
||
<span class="n">Any</span><span class="p">]</span> <span class="o">=</span> <span class="n">PrivateAttr</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="nb">dict</span><span class="p">)</span>
|
||
|
||
<div class="viewcode-block" id="TrtLlmArgs.init_calib_config">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TrtLlmArgs.init_calib_config">[docs]</a>
|
||
<span class="nd">@field_validator</span><span class="p">(</span><span class="s1">'calib_config'</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">'before'</span><span class="p">)</span>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">init_calib_config</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">v</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">v</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">CalibConfig</span><span class="p">()</span>
|
||
<span class="k">return</span> <span class="n">v</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="TrtLlmArgs.validate_quant_config">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TrtLlmArgs.validate_quant_config">[docs]</a>
|
||
<span class="nd">@field_validator</span><span class="p">(</span><span class="s2">"quant_config"</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">'before'</span><span class="p">)</span>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_quant_config</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">info</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">v</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">v</span> <span class="o">=</span> <span class="n">QuantConfig</span><span class="p">()</span>
|
||
<span class="k">return</span> <span class="n">v</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="TrtLlmArgs.setup_embedding_parallel_mode">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TrtLlmArgs.setup_embedding_parallel_mode">[docs]</a>
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">setup_embedding_parallel_mode</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">embedding_parallel_mode</span> <span class="o">==</span> <span class="s1">'NONE'</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_convert_checkpoint_options</span><span class="p">[</span><span class="s1">'use_parallel_embedding'</span><span class="p">]</span> <span class="o">=</span> <span class="kc">False</span>
|
||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">embedding_parallel_mode</span> <span class="o">==</span> <span class="s1">'SHARDING_ALONG_VOCAB'</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_convert_checkpoint_options</span><span class="p">[</span><span class="s1">'use_parallel_embedding'</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_convert_checkpoint_options</span><span class="p">[</span><span class="s1">'embedding_sharding_dim'</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">embedding_parallel_mode</span> <span class="o">==</span> <span class="s1">'SHARDING_ALONG_HIDDEN'</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_convert_checkpoint_options</span><span class="p">[</span><span class="s1">'use_parallel_embedding'</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_convert_checkpoint_options</span><span class="p">[</span><span class="s1">'embedding_sharding_dim'</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="c1"># No else clause needed since validation already happened</span>
|
||
<span class="k">return</span> <span class="bp">self</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="TrtLlmArgs.validate_enable_build_cache">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TrtLlmArgs.validate_enable_build_cache">[docs]</a>
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_enable_build_cache</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable_build_cache</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">enable_build_cache</span> <span class="o">=</span> <span class="n">BuildCacheConfig</span><span class="p">()</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">enable_build_cache</span><span class="p">,</span> <span class="nb">bool</span><span class="p">)</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable_build_cache</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">enable_build_cache</span><span class="p">,</span> <span class="n">BuildCacheConfig</span><span class="p">):</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"Invalid build_cache_config: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">enable_build_cache</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="TrtLlmArgs.validate_kv_cache_dtype">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TrtLlmArgs.validate_kv_cache_dtype">[docs]</a>
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_kv_cache_dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_config</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="s2">"auto"</span><span class="p">,</span> <span class="s2">"KvCacheConfig.dtype is not supported by the TensorRT backend."</span>
|
||
<span class="k">return</span> <span class="bp">self</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">LoadFormat</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
|
||
<span class="n">AUTO</span> <span class="o">=</span> <span class="mi">0</span>
|
||
<span class="c1"># Initialize all weights randomly.</span>
|
||
<span class="n">DUMMY</span> <span class="o">=</span> <span class="mi">1</span>
|
||
<span class="c1"># Only load the multimodal(vision) encoder weights</span>
|
||
<span class="n">VISION_ONLY</span> <span class="o">=</span> <span class="mi">2</span>
|
||
|
||
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">SamplerType</span><span class="p">(</span><span class="n">StrEnum</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""Enum for sampler type options."""</span>
|
||
<span class="n">TRTLLMSampler</span> <span class="o">=</span> <span class="s2">"TRTLLMSampler"</span>
|
||
<span class="n">TorchSampler</span> <span class="o">=</span> <span class="s2">"TorchSampler"</span>
|
||
<span class="n">auto</span> <span class="o">=</span> <span class="s2">"auto"</span>
|
||
|
||
|
||
<div class="viewcode-block" id="TorchCompileConfig">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchCompileConfig">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">TorchCompileConfig</span><span class="p">(</span><span class="n">StrictBaseModel</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""</span>
|
||
<span class="sd"> Configuration for torch.compile.</span>
|
||
<span class="sd"> """</span>
|
||
<span class="n">enable_fullgraph</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Enable full graph compilation in torch.compile."</span><span class="p">)</span>
|
||
|
||
<span class="n">enable_inductor</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"Enable inductor backend in torch.compile."</span><span class="p">)</span>
|
||
|
||
<span class="n">enable_piecewise_cuda_graph</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Enable piecewise CUDA graph in torch.compile."</span><span class="p">)</span>
|
||
|
||
<span class="n">capture_num_tokens</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"List of num of tokens to capture the piecewise CUDA graph for. If not provided, the number of tokens will be the same as cuda_graph_config.batch_sizes."</span>
|
||
<span class="p">)</span>
|
||
|
||
<div class="viewcode-block" id="TorchCompileConfig.validate_capture_num_tokens">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchCompileConfig.validate_capture_num_tokens">[docs]</a>
|
||
<span class="nd">@field_validator</span><span class="p">(</span><span class="s1">'capture_num_tokens'</span><span class="p">)</span>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_capture_num_tokens</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">v</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">v</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">v</span>
|
||
<span class="k">if</span> <span class="nb">any</span><span class="p">(</span><span class="n">t</span> <span class="o"><=</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">v</span><span class="p">):</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"capture_num_tokens must contain positive ints."</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">v</span><span class="p">),</span> <span class="n">reverse</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></div>
|
||
|
||
|
||
<span class="n">enable_userbuffers</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"When torch compile is enabled, userbuffers is enabled by default."</span><span class="p">)</span>
|
||
|
||
<span class="n">max_num_streams</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The maximum number of CUDA streams to use for torch.compile."</span><span class="p">)</span>
|
||
|
||
<div class="viewcode-block" id="TorchCompileConfig.validate_torch_compile_max_num_streams">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchCompileConfig.validate_torch_compile_max_num_streams">[docs]</a>
|
||
<span class="nd">@field_validator</span><span class="p">(</span><span class="s1">'max_num_streams'</span><span class="p">)</span>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_torch_compile_max_num_streams</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">v</span><span class="p">):</span>
|
||
<span class="w"> </span><span class="sd">"""Validate torch_compile_config.max_num_streams >= 1."""</span>
|
||
<span class="k">if</span> <span class="n">v</span> <span class="o"><</span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"torch_compile_config.max_num_streams must be >= 1"</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">v</span></div>
|
||
|
||
|
||
<span class="nd">@staticmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">_generate_capture_num_tokens</span><span class="p">()</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span>
|
||
<span class="k">return</span> <span class="p">[</span><span class="mi">2</span><span class="o">**</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">8</span><span class="p">)]</span> <span class="o">+</span> <span class="p">[</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">3073</span><span class="p">,</span> <span class="mi">256</span><span class="p">)]</span></div>
|
||
|
||
|
||
|
||
<div class="viewcode-block" id="TorchLlmArgs">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs">[docs]</a>
|
||
<span class="k">class</span><span class="w"> </span><span class="nc">TorchLlmArgs</span><span class="p">(</span><span class="n">BaseLlmArgs</span><span class="p">):</span>
|
||
<span class="c1"># Just a dummy BuildConfig to allow code reuse with the TrtLlmArgs</span>
|
||
<span class="n">build_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">BuildConfig</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Build config."</span><span class="p">,</span>
|
||
<span class="n">exclude_from_json</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"deprecated"</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="c1"># PyTorch backend specific configurations</span>
|
||
<span class="n">garbage_collection_gen0_threshold</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">20000</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Threshold for Python garbage collection of generation 0 objects."</span>
|
||
<span class="s2">"Lower values trigger more frequent garbage collection."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"beta"</span><span class="p">)</span>
|
||
|
||
<span class="n">cuda_graph_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">CudaGraphConfig</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default_factory</span><span class="o">=</span><span class="n">CudaGraphConfig</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"CUDA graph config.If true, use CUDA graphs for decoding. </span><span class="se">\</span>
|
||
<span class="s2"> CUDA graphs are only created for the batch sizes in cuda_graph_config.batch_sizes, </span><span class="se">\</span>
|
||
<span class="s2"> and are enabled for batches that consist of decoding requests *only* </span><span class="se">\</span>
|
||
<span class="s2"> (the reason is that it's hard to capture a single graph with prefill requests </span><span class="se">\</span>
|
||
<span class="s2"> since the input shapes are a function of the sequence lengths).</span><span class="se">\</span>
|
||
<span class="s2"> Note that each CUDA graph can use up to 200 MB of extra memory."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"beta"</span><span class="p">)</span>
|
||
|
||
<span class="n">attention_dp_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AttentionDpConfig</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Optimized load-balancing for the DP Attention scheduler."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"beta"</span><span class="p">)</span>
|
||
|
||
<span class="n">disable_overlap_scheduler</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Disable the overlap scheduler."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"beta"</span><span class="p">)</span>
|
||
|
||
<span class="n">moe_config</span><span class="p">:</span> <span class="n">MoeConfig</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="n">MoeConfig</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"MoE config."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"beta"</span><span class="p">)</span>
|
||
|
||
<span class="n">attn_backend</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="s1">'TRTLLM'</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Attention backend to use."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"beta"</span><span class="p">)</span>
|
||
|
||
<span class="n">sampler_type</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">SamplerType</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="n">SamplerType</span><span class="o">.</span><span class="n">auto</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The type of sampler to use. Options are TRTLLMSampler, TorchSampler or auto. Defaults to auto, which will use TorchSampler unless BeamSearch is requested."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"beta"</span><span class="p">)</span>
|
||
|
||
<span class="n">enable_iter_perf_stats</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Enable iteration performance statistics."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">enable_iter_req_stats</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"If true, enables per request stats per iteration. Must also set enable_iter_perf_stats to true to get request stats."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">print_iter_log</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Print iteration logs."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"beta"</span><span class="p">)</span>
|
||
|
||
<span class="n">perf_metrics_max_requests</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The maximum number of requests for perf metrics. Must also set request_perf_metrics to true to get perf metrics."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">batch_wait_timeout_ms</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"If greater than 0, the request queue might wait up to batch_wait_timeout_ms to receive max_batch_size requests, if fewer than max_batch_size requests are currently available. If 0, no waiting occurs."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">batch_wait_timeout_iters</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Maximum number of iterations the scheduler will wait to accumulate new coming requests for improved GPU utilization efficiency. If greater than 0, the scheduler will delay batch processing to gather more requests up to the specified iteration limit. If 0, disables timeout-iters-based batching delays."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">batch_wait_max_tokens_ratio</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Token accumulation threshold ratio for batch scheduling optimization. If greater than 0, the scheduler will accumulate requests locally until the total token count reaches batch_wait_max_tokens_ratio * max_num_tokens. This mechanism enhances GPU utilization efficiency by ensuring adequate batch sizes.If 0 disables token-based batching delays."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">torch_compile_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">TorchCompileConfig</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">description</span><span class="o">=</span><span class="s2">"Torch compile config."</span><span class="p">,</span> <span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">enable_autotuner</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Enable autotuner for all tunable ops. This flag is for debugging purposes only, and the performance may significantly degrade if set to false."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">enable_layerwise_nvtx_marker</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"If true, enable layerwise nvtx marker."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"beta"</span><span class="p">)</span>
|
||
|
||
<span class="n">load_format</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">LoadFormat</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="n">LoadFormat</span><span class="o">.</span><span class="n">AUTO</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"How to load the model weights. By default, detect the weight type from the model checkpoint."</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">enable_min_latency</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"If true, enable min-latency mode. Currently only used for Llama4."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"beta"</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="c1"># TODO: make this a per-request parameter</span>
|
||
<span class="n">stream_interval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The iteration interval to create responses under the streaming mode. "</span>
|
||
<span class="s2">"Set this to a larger value when the batch size is large, which helps reduce the streaming overhead."</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">force_dynamic_quantization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"If true, force dynamic quantization. Defaults to False."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">allreduce_strategy</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Literal</span><span class="p">[</span>
|
||
<span class="s1">'AUTO'</span><span class="p">,</span> <span class="s1">'NCCL'</span><span class="p">,</span> <span class="s1">'UB'</span><span class="p">,</span> <span class="s1">'MINLATENCY'</span><span class="p">,</span> <span class="s1">'ONESHOT'</span><span class="p">,</span> <span class="s1">'TWOSHOT'</span><span class="p">,</span>
|
||
<span class="s1">'LOWPRECISION'</span><span class="p">,</span> <span class="s1">'MNNVL'</span><span class="p">,</span>
|
||
<span class="s1">'NCCL_SYMMETRIC'</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="s1">'AUTO'</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"Allreduce strategy to use."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"beta"</span><span class="p">)</span>
|
||
<span class="n">checkpoint_loader</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">object</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The checkpoint loader to use for this LLM instance. You may use a custom checkpoint loader by subclassing "</span>
|
||
<span class="s2">"`BaseCheckpointLoader` and providing an instance of the subclass here to load weights from a custom "</span>
|
||
<span class="s2">"checkpoint format.</span><span class="se">\n</span><span class="s2">"</span>
|
||
<span class="s2">"If neither checkpoint_format nor checkpoint_loader are provided, checkpoint_format will be set to HF "</span>
|
||
<span class="s2">"and the default HfCheckpointLoader will be used.</span><span class="se">\n</span><span class="s2">"</span>
|
||
<span class="s2">"If checkpoint_format and checkpoint_loader are both provided, checkpoint_loader will be ignored."</span><span class="p">,</span>
|
||
<span class="n">json_schema_extra</span><span class="o">=</span><span class="p">{</span>
|
||
<span class="s2">"type"</span><span class="p">:</span>
|
||
<span class="s2">"Optional[tensorrt_llm._torch.models.checkpoints.BaseCheckpointLoader]"</span>
|
||
<span class="p">},</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">checkpoint_format</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"The format of the provided checkpoint. You may use a custom checkpoint format by subclassing "</span>
|
||
<span class="s2">"`BaseCheckpointLoader` and registering it with `register_checkpoint_loader`.</span><span class="se">\n</span><span class="s2">"</span>
|
||
<span class="s2">"If neither checkpoint_format nor checkpoint_loader are provided, checkpoint_format will be set to HF "</span>
|
||
<span class="s2">"and the default HfCheckpointLoader will be used.</span><span class="se">\n</span><span class="s2">"</span>
|
||
<span class="s2">"If checkpoint_format and checkpoint_loader are both provided, checkpoint_loader will be ignored."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">kv_connector_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">KvCacheConnectorConfig</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The config for KV cache connector."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">mm_encoder_only</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Only load/execute the vision encoder part of the full model. Defaults to False."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="n">ray_worker_extension_cls</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span><span class="s2">"The full worker extension class name including module path."</span>
|
||
<span class="s2">"Allows users to extend the functions of the RayGPUWorker class."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="n">enable_sleep</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span>
|
||
<span class="n">default</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
||
<span class="n">description</span><span class="o">=</span>
|
||
<span class="s2">"Enable LLM sleep feature. Sleep feature requires extra setup that may slowdown model loading."</span>
|
||
<span class="s2">"Only enable it if you intend to use this feature."</span><span class="p">,</span>
|
||
<span class="n">status</span><span class="o">=</span><span class="s2">"prototype"</span><span class="p">)</span>
|
||
|
||
<span class="c1"># PrivateVars</span>
|
||
<span class="n">_quant_config</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">QuantConfig</span><span class="p">]</span> <span class="o">=</span> <span class="n">PrivateAttr</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
|
||
|
||
<span class="n">_disable_flash_infer_sampling</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">PrivateAttr</span><span class="p">(</span><span class="n">default</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
||
<span class="w"> </span><span class="sd">"""Unless this is set to False, FlashInfer.sampling is not used, even if available."""</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">quant_config</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">QuantConfig</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_quant_config</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_quant_config</span> <span class="o">=</span> <span class="n">QuantConfig</span><span class="p">()</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_quant_config</span>
|
||
|
||
<span class="nd">@quant_config</span><span class="o">.</span><span class="n">setter</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">quant_config</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">QuantConfig</span><span class="p">):</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_quant_config</span> <span class="o">=</span> <span class="n">value</span>
|
||
|
||
<span class="c1"># TODO: remove backend later</span>
|
||
<div class="viewcode-block" id="TorchLlmArgs.init_backend">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs.init_backend">[docs]</a>
|
||
<span class="nd">@field_validator</span><span class="p">(</span><span class="s1">'backend'</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">'before'</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">init_backend</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">v</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="n">v</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="s1">'pytorch'</span>
|
||
<span class="k">return</span> <span class="n">v</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="TorchLlmArgs.convert_load_format">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs.convert_load_format">[docs]</a>
|
||
<span class="nd">@field_validator</span><span class="p">(</span><span class="s1">'load_format'</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">'before'</span><span class="p">)</span>
|
||
<span class="nd">@classmethod</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">convert_load_format</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">v</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">LoadFormat</span><span class="p">):</span>
|
||
<span class="k">return</span> <span class="n">v</span>
|
||
<span class="n">load_format</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">upper</span><span class="p">()</span>
|
||
<span class="k">if</span> <span class="n">load_format</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">LoadFormat</span><span class="o">.</span><span class="n">__members__</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Invalid LoadFormat: </span><span class="si">{</span><span class="n">v</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">LoadFormat</span><span class="p">[</span><span class="n">load_format</span><span class="p">]</span></div>
|
||
|
||
|
||
<span class="c1"># Extra resource managers to use in addition to the KV cache manager.</span>
|
||
<span class="c1"># Each manager's prepare_resources method is called before the forward pass,</span>
|
||
<span class="c1"># and update_resources() is called after the pass finishes. free_resources()</span>
|
||
<span class="c1"># is called when a request finishes. The KV cache manager is guaranteed to</span>
|
||
<span class="c1"># be invoked after all of these extra managers in all stages.</span>
|
||
<span class="n">_extra_resource_managers</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span>
|
||
<span class="nb">object</span><span class="p">]</span> <span class="o">=</span> <span class="n">PrivateAttr</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="nb">dict</span><span class="p">,</span> <span class="p">)</span>
|
||
|
||
<span class="nd">@property</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">extra_resource_managers</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">object</span><span class="p">]:</span>
|
||
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_extra_resource_managers</span>
|
||
|
||
<span class="nd">@extra_resource_managers</span><span class="o">.</span><span class="n">setter</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">extra_resource_managers</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">object</span><span class="p">])</span> <span class="o">-></span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">_extra_resource_managers</span> <span class="o">=</span> <span class="n">value</span>
|
||
|
||
<div class="viewcode-block" id="TorchLlmArgs.validate_stream_interval">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs.validate_stream_interval">[docs]</a>
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_stream_interval</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream_interval</span> <span class="o"><=</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"stream_interval must be positive, got </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">stream_interval</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="TorchLlmArgs.validate_checkpoint_format">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs.validate_checkpoint_format">[docs]</a>
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_checkpoint_format</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_format</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_loader</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="s2">"checkpoint_format and checkpoint_loader are both provided, "</span>
|
||
<span class="s2">"checkpoint_loader will be ignored."</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_loader</span> <span class="o">=</span> <span class="kc">None</span>
|
||
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_format</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_loader</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
|
||
<span class="s2">"neither checkpoint_format nor checkpoint_loader were provided, "</span>
|
||
<span class="s2">"checkpoint_format will be set to HF."</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_format</span> <span class="o">=</span> <span class="s2">"HF"</span>
|
||
|
||
<span class="k">return</span> <span class="bp">self</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="TorchLlmArgs.validate_load_balancer">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs.validate_load_balancer">[docs]</a>
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"after"</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_load_balancer</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s1">'TorchLlmArgs'</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_config</span><span class="o">.</span><span class="n">load_balancer</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_config</span><span class="o">.</span><span class="n">load_balancer</span><span class="p">):</span>
|
||
<span class="k">raise</span> <span class="ne">FileNotFoundError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"MoE load balancer config file not found: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_config</span><span class="o">.</span><span class="n">load_balancer</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">try</span><span class="p">:</span>
|
||
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_config</span><span class="o">.</span><span class="n">load_balancer</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
|
||
<span class="n">moe_load_balancer_config</span> <span class="o">=</span> <span class="n">yaml</span><span class="o">.</span><span class="n">safe_load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">moe_config</span><span class="o">.</span><span class="n">load_balancer</span> <span class="o">=</span> <span class="n">MoeLoadBalancerConfig</span><span class="p">(</span>
|
||
<span class="o">**</span><span class="n">moe_load_balancer_config</span><span class="p">)</span>
|
||
<span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"Failed to load MoE load balancer config file: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_config</span><span class="o">.</span><span class="n">load_balancer</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span> <span class="kn">from</span><span class="w"> </span><span class="nn">e</span>
|
||
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_config</span><span class="o">.</span><span class="n">load_balancer</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="k">try</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">moe_config</span><span class="o">.</span><span class="n">load_balancer</span> <span class="o">=</span> <span class="n">MoeLoadBalancerConfig</span><span class="p">(</span>
|
||
<span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_config</span><span class="o">.</span><span class="n">load_balancer</span><span class="p">)</span>
|
||
<span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"Failed to load MoE load balancer config: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">moe_config</span><span class="o">.</span><span class="n">load_balancer</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span> <span class="kn">from</span><span class="w"> </span><span class="nn">e</span>
|
||
<span class="k">return</span> <span class="bp">self</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="TorchLlmArgs.validate_cuda_graph_config">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs.validate_cuda_graph_config">[docs]</a>
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s1">'after'</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_cuda_graph_config</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s1">'TorchLlmArgs'</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""Validate CUDA graph configuration.</span>
|
||
|
||
<span class="sd"> Ensures that:</span>
|
||
<span class="sd"> 1. If cuda_graph_config.batch_sizes is provided, cuda_graph_config.max_batch_size must be 0</span>
|
||
<span class="sd"> 2. If cuda_graph_config.batch_sizes is not provided, it is generated based on cuda_graph_config.max_batch_size</span>
|
||
<span class="sd"> 3. If both are provided, cuda_graph_config.batch_sizes must match the generated values</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_config</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
|
||
<span class="n">config</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_config</span>
|
||
|
||
<span class="k">if</span> <span class="n">config</span><span class="o">.</span><span class="n">batch_sizes</span><span class="p">:</span>
|
||
<span class="n">config</span><span class="o">.</span><span class="n">batch_sizes</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">batch_sizes</span><span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">config</span><span class="o">.</span><span class="n">max_batch_size</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">config</span><span class="o">.</span><span class="n">batch_sizes</span> <span class="o">!=</span> <span class="n">CudaGraphConfig</span><span class="o">.</span><span class="n">_generate_cuda_graph_batch_sizes</span><span class="p">(</span>
|
||
<span class="n">config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="p">,</span> <span class="n">config</span><span class="o">.</span><span class="n">enable_padding</span><span class="p">):</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"Please don't set both cuda_graph_config.batch_sizes "</span>
|
||
<span class="s2">"and cuda_graph_config.max_batch_size.</span><span class="se">\n</span><span class="s2">"</span>
|
||
<span class="sa">f</span><span class="s2">"cuda_graph_config.batch_sizes: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_config</span><span class="o">.</span><span class="n">batch_sizes</span><span class="si">}</span><span class="s2">, "</span>
|
||
<span class="sa">f</span><span class="s2">"cuda_graph_config.max_batch_size: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">cuda_graph_config</span><span class="o">.</span><span class="n">max_batch_size</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">config</span><span class="o">.</span><span class="n">max_batch_size</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">config</span><span class="o">.</span><span class="n">batch_sizes</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">max_batch_size</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">max_batch_size</span> <span class="ow">or</span> <span class="mi">128</span>
|
||
<span class="n">generated_sizes</span> <span class="o">=</span> <span class="n">CudaGraphConfig</span><span class="o">.</span><span class="n">_generate_cuda_graph_batch_sizes</span><span class="p">(</span>
|
||
<span class="n">max_batch_size</span><span class="p">,</span> <span class="n">config</span><span class="o">.</span><span class="n">enable_padding</span><span class="p">)</span>
|
||
<span class="n">config</span><span class="o">.</span><span class="n">batch_sizes</span> <span class="o">=</span> <span class="n">generated_sizes</span>
|
||
<span class="n">config</span><span class="o">.</span><span class="n">max_batch_size</span> <span class="o">=</span> <span class="n">max_batch_size</span>
|
||
|
||
<span class="k">return</span> <span class="bp">self</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="TorchLlmArgs.validate_torch_compile_config">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs.validate_torch_compile_config">[docs]</a>
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s1">'after'</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_torch_compile_config</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s1">'TorchLlmArgs'</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">torch_compile_config</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
|
||
<span class="n">config</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">torch_compile_config</span>
|
||
<span class="k">if</span> <span class="n">config</span><span class="o">.</span><span class="n">enable_piecewise_cuda_graph</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">config</span><span class="o">.</span><span class="n">capture_num_tokens</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="n">config</span><span class="o">.</span><span class="n">capture_num_tokens</span> <span class="o">=</span> <span class="n">TorchCompileConfig</span><span class="o">.</span><span class="n">_generate_capture_num_tokens</span><span class="p">(</span>
|
||
<span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="TorchLlmArgs.sync_quant_config_with_kv_cache_config_dtype">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs.sync_quant_config_with_kv_cache_config_dtype">[docs]</a>
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s1">'after'</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">sync_quant_config_with_kv_cache_config_dtype</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s1">'TorchLlmArgs'</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_config</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
|
||
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">quant_config</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_config</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="s2">"auto"</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_config</span><span class="o">.</span><span class="n">dtype</span> <span class="o">==</span> <span class="s1">'fp8'</span><span class="p">:</span>
|
||
<span class="bp">self</span><span class="o">.</span><span class="n">quant_config</span><span class="o">.</span><span class="n">kv_cache_quant_algo</span> <span class="o">=</span> <span class="n">QuantAlgo</span><span class="o">.</span><span class="n">FP8</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"Cannot sync quant_config.kv_cache_quant_algo with kv_cache_config.dtype of </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">kv_cache_config</span><span class="o">.</span><span class="n">dtype</span><span class="si">}</span><span class="s2">, "</span>
|
||
<span class="s2">"please update the validator"</span><span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="bp">self</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="TorchLlmArgs.warn_on_unstable_feature_usage">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs.warn_on_unstable_feature_usage">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">warn_on_unstable_feature_usage</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s1">'TorchLlmArgs'</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""Warn on unstable feature usage."""</span>
|
||
<span class="n">set_fields</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_dump</span><span class="p">(</span><span class="n">exclude_unset</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
|
||
|
||
<span class="k">for</span> <span class="n">field_name</span> <span class="ow">in</span> <span class="n">set_fields</span><span class="p">:</span>
|
||
<span class="n">field_info</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_fields</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">field_name</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="n">field_info</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">field_info</span><span class="o">.</span><span class="n">json_schema_extra</span><span class="p">:</span>
|
||
<span class="k">continue</span>
|
||
|
||
<span class="n">status</span> <span class="o">=</span> <span class="n">field_info</span><span class="o">.</span><span class="n">json_schema_extra</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">'status'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
||
|
||
<span class="k">if</span> <span class="n">status</span> <span class="ow">in</span> <span class="p">(</span><span class="s1">'beta'</span><span class="p">,</span> <span class="s1">'prototype'</span><span class="p">):</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"The '</span><span class="si">{</span><span class="n">field_name</span><span class="si">}</span><span class="s2">' knob is a '</span><span class="si">{</span><span class="n">status</span><span class="si">}</span><span class="s2">' feature. "</span>
|
||
<span class="s2">"It is not recommended for production use and may change or be removed."</span><span class="p">,</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="k">return</span> <span class="bp">self</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="TorchLlmArgs.validate_attention_dp_config">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs.validate_attention_dp_config">[docs]</a>
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s1">'after'</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_attention_dp_config</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s1">'TorchLlmArgs'</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""Validate attention DP configuration.</span>
|
||
|
||
<span class="sd"> Ensures that:</span>
|
||
<span class="sd"> 1. If attention_dp_config.enable_balance is true, attention_dp_config.batching_wait_iters must be greater or equal to 0</span>
|
||
<span class="sd"> 2. If attention_dp_config.enable_balance is true, attention_dp_config.timeout_iters must be greater or equal to 0</span>
|
||
<span class="sd"> """</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_dp_config</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="bp">self</span>
|
||
|
||
<span class="n">config</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_dp_config</span>
|
||
<span class="k">if</span> <span class="n">config</span><span class="o">.</span><span class="n">enable_balance</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">config</span><span class="o">.</span><span class="n">batching_wait_iters</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"attention_dp_config.batching_wait_iters must be greater or equal to 0 when enable_balance is true"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">if</span> <span class="n">config</span><span class="o">.</span><span class="n">timeout_iters</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="s2">"attention_dp_config.timeout_iters must be greater or equal to 0 when enable_balance is true"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="TorchLlmArgs.validate_batch_wait_timeout_ms">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs.validate_batch_wait_timeout_ms">[docs]</a>
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s1">'after'</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_batch_wait_timeout_ms</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s1">'TorchLlmArgs'</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">"""Validate batch wait timeout."""</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_wait_timeout_ms</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"batch_wait_timeout_ms must be greater than or equal to 0"</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="TorchLlmArgs.validate_batch_wait_timeout_iters">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs.validate_batch_wait_timeout_iters">[docs]</a>
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s1">'after'</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_batch_wait_timeout_iters</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s1">'TorchLlmArgs'</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_wait_timeout_iters</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"batch_wait_timeout_iters must be >= 0, got </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_wait_timeout_iters</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="TorchLlmArgs.validate_batch_wait_max_tokens_ratio">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs.validate_batch_wait_max_tokens_ratio">[docs]</a>
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s1">'after'</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_batch_wait_max_tokens_ratio</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s1">'TorchLlmArgs'</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_wait_max_tokens_ratio</span> <span class="o"><</span> <span class="mi">0</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_wait_max_tokens_ratio</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"batch_wait_max_tokens_ratio must be in range [0, 1], got </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_wait_max_tokens_ratio</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="TorchLlmArgs.validate_ray_worker_extension_cls">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs.validate_ray_worker_extension_cls">[docs]</a>
|
||
<span class="nd">@model_validator</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s1">'after'</span><span class="p">)</span>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">validate_ray_worker_extension_cls</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="s1">'TorchLlmArgs'</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ray_worker_extension_cls</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">orchestrator_type</span> <span class="o">!=</span> <span class="s2">"ray"</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"ray_worker_extension_cls is only supported with orchestrator_type='ray'"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">return</span> <span class="bp">self</span></div>
|
||
|
||
|
||
<div class="viewcode-block" id="TorchLlmArgs.get_executor_config">
|
||
<a class="viewcode-back" href="../../../llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs.get_executor_config">[docs]</a>
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">get_executor_config</span><span class="p">(</span>
|
||
<span class="bp">self</span><span class="p">,</span>
|
||
<span class="n">_hf_model_dir</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Path</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="n">tokenizer</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">TokenizerBase</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
|
||
<span class="p">)</span> <span class="o">-></span> <span class="n">_ExecutorConfig</span><span class="p">:</span>
|
||
<span class="n">executor_config</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_executor_config</span><span class="p">(</span><span class="n">_hf_model_dir</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">)</span>
|
||
<span class="n">executor_config</span><span class="o">.</span><span class="n">mm_encoder_only</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mm_encoder_only</span>
|
||
<span class="k">return</span> <span class="n">executor_config</span></div>
|
||
</div>
|
||
|
||
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">update_llm_args_with_extra_dict</span><span class="p">(</span>
|
||
<span class="n">llm_args</span><span class="p">:</span> <span class="n">Dict</span><span class="p">,</span>
|
||
<span class="n">llm_args_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">,</span>
|
||
<span class="n">extra_llm_api_options</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">:</span>
|
||
|
||
<span class="n">field_mapping</span> <span class="o">=</span> <span class="p">{</span>
|
||
<span class="s2">"quant_config"</span><span class="p">:</span> <span class="n">QuantConfig</span><span class="p">,</span>
|
||
<span class="s2">"calib_config"</span><span class="p">:</span> <span class="n">CalibConfig</span><span class="p">,</span>
|
||
<span class="s2">"build_config"</span><span class="p">:</span> <span class="n">BuildConfig</span><span class="p">,</span>
|
||
<span class="s2">"decoding_config"</span><span class="p">:</span> <span class="n">DecodingConfig</span><span class="p">,</span>
|
||
<span class="s2">"enable_build_cache"</span><span class="p">:</span> <span class="n">BuildCacheConfig</span><span class="p">,</span>
|
||
<span class="s2">"speculative_config"</span><span class="p">:</span> <span class="n">DecodingBaseConfig</span><span class="p">,</span>
|
||
<span class="s2">"lora_config"</span><span class="p">:</span> <span class="n">LoraConfig</span><span class="p">,</span>
|
||
<span class="s2">"moe_config"</span><span class="p">:</span> <span class="n">MoeConfig</span><span class="p">,</span>
|
||
<span class="s2">"attention_dp_config"</span><span class="p">:</span> <span class="n">AttentionDpConfig</span><span class="p">,</span>
|
||
<span class="s2">"sparse_attention_config"</span><span class="p">:</span> <span class="n">BaseSparseAttentionConfig</span><span class="p">,</span>
|
||
<span class="p">}</span>
|
||
<span class="k">for</span> <span class="n">field_name</span><span class="p">,</span> <span class="n">field_type</span> <span class="ow">in</span> <span class="n">field_mapping</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
|
||
<span class="k">if</span> <span class="n">field_name</span> <span class="ow">in</span> <span class="n">llm_args_dict</span><span class="p">:</span>
|
||
<span class="c1"># Some fields need to be converted manually.</span>
|
||
<span class="k">if</span> <span class="n">field_name</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">"speculative_config"</span><span class="p">,</span> <span class="s2">"sparse_attention_config"</span><span class="p">]:</span>
|
||
<span class="n">llm_args_dict</span><span class="p">[</span><span class="n">field_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">field_type</span><span class="o">.</span><span class="n">from_dict</span><span class="p">(</span>
|
||
<span class="n">llm_args_dict</span><span class="p">[</span><span class="n">field_name</span><span class="p">])</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">llm_args_dict</span><span class="p">[</span><span class="n">field_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">field_type</span><span class="p">(</span>
|
||
<span class="o">**</span><span class="n">llm_args_dict</span><span class="p">[</span><span class="n">field_name</span><span class="p">])</span>
|
||
<span class="n">extra_llm_str</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"because it's specified in </span><span class="si">{</span><span class="n">extra_llm_api_options</span><span class="si">}</span><span class="s2">"</span> <span class="k">if</span> <span class="n">extra_llm_api_options</span> <span class="k">else</span> <span class="s2">""</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Overriding </span><span class="si">{</span><span class="n">field_name</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">extra_llm_str</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
||
|
||
<span class="n">llm_args</span> <span class="o">=</span> <span class="n">llm_args</span> <span class="o">|</span> <span class="n">llm_args_dict</span>
|
||
|
||
<span class="c1"># For trtllm-bench or trtllm-serve, build_config may be passed for the PyTorch</span>
|
||
<span class="c1"># backend, overwriting the knobs there since build_config always has the highest priority</span>
|
||
<span class="k">if</span> <span class="s2">"build_config"</span> <span class="ow">in</span> <span class="n">llm_args</span><span class="p">:</span>
|
||
<span class="c1"># Ensure build_config is a BuildConfig object, not a dict</span>
|
||
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">llm_args</span><span class="p">[</span><span class="s2">"build_config"</span><span class="p">],</span> <span class="nb">dict</span><span class="p">):</span>
|
||
<span class="n">llm_args</span><span class="p">[</span><span class="s2">"build_config"</span><span class="p">]</span> <span class="o">=</span> <span class="n">BuildConfig</span><span class="p">(</span><span class="o">**</span><span class="n">llm_args</span><span class="p">[</span><span class="s2">"build_config"</span><span class="p">])</span>
|
||
|
||
<span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="p">[</span>
|
||
<span class="s2">"max_batch_size"</span><span class="p">,</span>
|
||
<span class="s2">"max_num_tokens"</span><span class="p">,</span>
|
||
<span class="s2">"max_beam_width"</span><span class="p">,</span>
|
||
<span class="s2">"max_seq_len"</span><span class="p">,</span>
|
||
<span class="p">]:</span>
|
||
<span class="k">if</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">llm_args_dict</span><span class="p">:</span>
|
||
<span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"Overriding </span><span class="si">{</span><span class="n">key</span><span class="si">}</span><span class="s2"> from build_config to </span><span class="si">{</span><span class="n">llm_args_dict</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
<span class="nb">setattr</span><span class="p">(</span><span class="n">llm_args</span><span class="p">[</span><span class="s2">"build_config"</span><span class="p">],</span> <span class="n">key</span><span class="p">,</span> <span class="n">llm_args_dict</span><span class="p">[</span><span class="n">key</span><span class="p">])</span>
|
||
|
||
<span class="k">return</span> <span class="n">llm_args</span>
|
||
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">update_llm_args_with_extra_options</span><span class="p">(</span><span class="n">llm_args</span><span class="p">:</span> <span class="n">Dict</span><span class="p">,</span>
|
||
<span class="n">extra_llm_api_options</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="n">extra_llm_api_options</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
||
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">extra_llm_api_options</span><span class="p">,</span> <span class="s1">'r'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
|
||
<span class="n">llm_args_dict</span> <span class="o">=</span> <span class="n">yaml</span><span class="o">.</span><span class="n">safe_load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
|
||
<span class="n">llm_args</span> <span class="o">=</span> <span class="n">update_llm_args_with_extra_dict</span><span class="p">(</span><span class="n">llm_args</span><span class="p">,</span> <span class="n">llm_args_dict</span><span class="p">,</span>
|
||
<span class="n">extra_llm_api_options</span><span class="p">)</span>
|
||
<span class="k">return</span> <span class="n">llm_args</span>
|
||
|
||
|
||
<span class="k">def</span><span class="w"> </span><span class="nf">get_model_format</span><span class="p">(</span><span class="n">model_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
|
||
<span class="n">trust_remote_code</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-></span> <span class="n">_ModelFormatKind</span><span class="p">:</span>
|
||
<span class="w"> </span><span class="sd">''' Get the format of the model. '''</span>
|
||
<span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">model_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'config.json'</span><span class="p">)</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"Failed to infer model format because no config.json exists in </span><span class="si">{</span><span class="n">model_dir</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
|
||
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">model_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'config.json'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
|
||
<span class="n">config</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
|
||
|
||
<span class="k">try</span><span class="p">:</span>
|
||
<span class="k">if</span> <span class="s1">'pretrained_config'</span> <span class="ow">in</span> <span class="n">config</span> <span class="ow">and</span> <span class="s1">'build_config'</span> <span class="ow">in</span> <span class="n">config</span><span class="p">:</span>
|
||
<span class="n">model_format</span> <span class="o">=</span> <span class="n">_ModelFormatKind</span><span class="o">.</span><span class="n">TLLM_ENGINE</span>
|
||
<span class="n">EngineConfig</span><span class="o">.</span><span class="n">from_json_file</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">model_dir</span><span class="p">)</span> <span class="o">/</span> <span class="s1">'config.json'</span><span class="p">)</span>
|
||
<span class="k">elif</span> <span class="s1">'architecture'</span> <span class="ow">in</span> <span class="n">config</span> <span class="ow">and</span> <span class="s1">'dtype'</span> <span class="ow">in</span> <span class="n">config</span><span class="p">:</span>
|
||
<span class="n">model_format</span> <span class="o">=</span> <span class="n">_ModelFormatKind</span><span class="o">.</span><span class="n">TLLM_CKPT</span>
|
||
<span class="n">PretrainedConfig</span><span class="o">.</span><span class="n">from_checkpoint</span><span class="p">(</span><span class="n">model_dir</span><span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="n">model_format</span> <span class="o">=</span> <span class="n">_ModelFormatKind</span><span class="o">.</span><span class="n">HF</span>
|
||
<span class="n">AutoConfig</span><span class="o">.</span><span class="n">from_hugging_face</span><span class="p">(</span><span class="n">model_dir</span><span class="p">,</span>
|
||
<span class="n">trust_remote_code</span><span class="o">=</span><span class="n">trust_remote_code</span><span class="p">)</span>
|
||
<span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
|
||
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
|
||
<span class="sa">f</span><span class="s2">"Inferred model format </span><span class="si">{</span><span class="n">model_format</span><span class="si">}</span><span class="s2">, but failed to load config.json: </span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s2">"</span>
|
||
<span class="p">)</span>
|
||
<span class="k">else</span><span class="p">:</span>
|
||
<span class="k">return</span> <span class="n">model_format</span>
|
||
|
||
|
||
<span class="n">LlmArgs</span> <span class="o">=</span> <span class="n">TorchLlmArgs</span>
|
||
|
||
<span class="n">TRT_LLMARGS_EXPLICIT_DOCSTRING</span> <span class="o">=</span> <span class="n">generate_api_docs_as_docstring</span><span class="p">(</span><span class="n">TrtLlmArgs</span><span class="p">,</span>
|
||
<span class="n">indent</span><span class="o">=</span><span class="s1">' '</span> <span class="o">*</span> <span class="mi">4</span><span class="p">)</span>
|
||
<span class="n">TORCH_LLMARGS_EXPLICIT_DOCSTRING</span> <span class="o">=</span> <span class="n">generate_api_docs_as_docstring</span><span class="p">(</span><span class="n">TorchLlmArgs</span><span class="p">,</span>
|
||
<span class="n">indent</span><span class="o">=</span><span class="s1">' '</span> <span class="o">*</span>
|
||
<span class="mi">4</span><span class="p">)</span>
|
||
</pre></div>
|
||
|
||
</article>
|
||
|
||
|
||
|
||
|
||
|
||
<footer class="prev-next-footer d-print-none">
|
||
|
||
<div class="prev-next-area">
|
||
</div>
|
||
</footer>
|
||
|
||
</div>
|
||
|
||
|
||
|
||
<div class="bd-sidebar-secondary"></div>
|
||
|
||
|
||
|
||
|
||
|
||
</div>
|
||
<footer class="bd-footer-content">
|
||
|
||
</footer>
|
||
|
||
</main>
|
||
</div>
|
||
</div>
|
||
|
||
|
||
<!-- Scripts loaded after <body> so the DOM is not blocked -->
|
||
<script defer src="../../../_static/scripts/bootstrap.js?digest=8878045cc6db502f8baf"></script>
|
||
<script defer src="../../../_static/scripts/pydata-sphinx-theme.js?digest=8878045cc6db502f8baf"></script>
|
||
|
||
|
||
<footer class="bd-footer">
|
||
<div class="bd-footer__inner bd-page-width">
|
||
|
||
<div class="footer-items__start">
|
||
|
||
<div class="footer-item">
|
||
<a class="footer-brand logo" href="https://www.nvidia.com">
|
||
<img src="../../../_static/nvidia-logo-horiz-rgb-1c-blk-for-screen.svg" class="logo__image only-light" alt="NVIDIA"/>
|
||
<img src="../../../_static/nvidia-logo-horiz-rgb-1c-wht-for-screen.svg" class="logo__image only-dark" alt="NVIDIA"/>
|
||
</a></div>
|
||
|
||
<div class="footer-item">
|
||
|
||
<div class="footer-links">
|
||
|
||
|
||
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/privacy-policy/">Privacy Policy</a>
|
||
|
|
||
|
||
|
||
|
||
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/privacy-center/">Your Privacy Choices</a>
|
||
|
|
||
|
||
|
||
|
||
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/terms-of-service/">Terms of Service</a>
|
||
|
|
||
|
||
|
||
|
||
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/accessibility/">Accessibility</a>
|
||
|
|
||
|
||
|
||
|
||
<a class="external" href="https://www.nvidia.com/en-us/about-nvidia/company-policies/">Corporate Policies</a>
|
||
|
|
||
|
||
|
||
|
||
<a class="external" href="https://www.nvidia.com/en-us/product-security/">Product Security</a>
|
||
|
|
||
|
||
|
||
|
||
<a class="external" href="https://www.nvidia.com/en-us/contact/">Contact</a>
|
||
|
||
|
||
|
||
</div>
|
||
</div>
|
||
|
||
<div class="footer-item">
|
||
|
||
|
||
|
||
|
||
<p class="copyright">
|
||
|
||
      Copyright © 2025, NVIDIA.
|
||
<br/>
|
||
|
||
</p>
|
||
</div>
|
||
|
||
<div class="footer-item">
|
||
<div class="extra_footer">
|
||
|
||
<p>Last updated on November 23, 2025.</p>
|
||
|
||
<p>This page is generated by TensorRT-LLM commit <a href="https://github.com/NVIDIA/TensorRT-LLM/tree/a761585">a761585</a>.</p>
|
||
|
||
</div></div>
|
||
|
||
</div>
|
||
|
||
|
||
|
||
</div>
|
||
|
||
</footer>
|
||
</body>
|
||
</html> |