TensorRT-LLMs/reference/memory.html
2024-12-04 14:25:18 +08:00

<!DOCTYPE html>
<html class="writer-html5" lang="en" data-content_root="../">
<head>
<meta charset="utf-8" /><meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Memory Usage of TensorRT-LLM &mdash; tensorrt_llm documentation</title>
<link rel="stylesheet" type="text/css" href="../_static/pygments.css?v=80d5e7a1" />
<link rel="stylesheet" type="text/css" href="../_static/css/theme.css?v=e59714d7" />
<link rel="stylesheet" type="text/css" href="../_static/copybutton.css?v=76b2166b" />
<script src="../_static/jquery.js?v=5d32c60e"></script>
<script src="../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c"></script>
<script src="../_static/documentation_options.js?v=5929fcd5"></script>
<script src="../_static/doctools.js?v=888ff710"></script>
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
<script src="../_static/clipboard.min.js?v=a7894cd8"></script>
<script src="../_static/copybutton.js?v=65e89d2a"></script>
<script src="../_static/js/theme.js"></script>
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token" href="../blogs/H100vsA100.html" />
<link rel="prev" title="Numerical Precision" href="precision.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="../index.html" class="icon icon-home">
tensorrt_llm
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../overview.html">Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../quick-start-guide.html">Quick Start Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../key-features.html">Key Features</a></li>
<li class="toctree-l1"><a class="reference internal" href="../release-notes.html">Release Notes</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Installation</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../installation/linux.html">Installing on Linux</a></li>
<li class="toctree-l1"><a class="reference internal" href="../installation/build-from-source-linux.html">Building from Source Code on Linux</a></li>
<li class="toctree-l1"><a class="reference internal" href="../installation/windows.html">Installing on Windows</a></li>
<li class="toctree-l1"><a class="reference internal" href="../installation/build-from-source-windows.html">Building from Source Code on Windows</a></li>
<li class="toctree-l1"><a class="reference internal" href="../installation/grace-hopper.html">Installing on Grace Hopper</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">LLM API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../llm-api/index.html">API Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../llm-api/reference.html">API Reference</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">LLM API Examples</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../llm-api-examples/index.html">LLM Examples Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="../llm-api-examples/customization.html">Common Customizations</a></li>
<li class="toctree-l1"><a class="reference internal" href="../llm-api-examples/llm_api_examples.html">Examples</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Model Definition API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.layers.html">Layers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.functional.html">Functionals</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.models.html">Models</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.plugin.html">Plugin</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.quantization.html">Quantization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../python-api/tensorrt_llm.runtime.html">Runtime</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">C++ API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../_cpp_gen/executor.html">Executor</a></li>
<li class="toctree-l1"><a class="reference internal" href="../_cpp_gen/runtime.html">Runtime</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Command-Line Reference</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../commands/trtllm-build.html">trtllm-build</a></li>
<li class="toctree-l1"><a class="reference internal" href="../commands/trtllm-serve.html">trtllm-serve</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Architecture</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../architecture/overview.html">TensorRT-LLM Architecture</a></li>
<li class="toctree-l1"><a class="reference internal" href="../architecture/core-concepts.html">Model Definition</a></li>
<li class="toctree-l1"><a class="reference internal" href="../architecture/core-concepts.html#compilation">Compilation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../architecture/core-concepts.html#runtime">Runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="../architecture/core-concepts.html#multi-gpu-and-multi-node-support">Multi-GPU and Multi-Node Support</a></li>
<li class="toctree-l1"><a class="reference internal" href="../architecture/checkpoint.html">TensorRT-LLM Checkpoint</a></li>
<li class="toctree-l1"><a class="reference internal" href="../architecture/workflow.html">TensorRT-LLM Build Workflow</a></li>
<li class="toctree-l1"><a class="reference internal" href="../architecture/add-model.html">Adding a Model</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Advanced</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../advanced/gpt-attention.html">Multi-Head, Multi-Query, and Group-Query Attention</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/gpt-runtime.html">C++ GPT Runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/executor.html">Executor API</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/graph-rewriting.html">Graph Rewriting Module</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/inference-request.html">Inference Request</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/inference-request.html#responses">Responses</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/lora.html">Run gpt-2b + LoRA using GptManager / cpp runtime</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/expert-parallelism.html">Expert Parallelism in TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/kv-cache-reuse.html">KV cache reuse</a></li>
<li class="toctree-l1"><a class="reference internal" href="../advanced/speculative-decoding.html">Speculative Sampling</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Performance</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../performance/perf-overview.html">Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../performance/perf-benchmarking.html">Benchmarking</a></li>
<li class="toctree-l1"><a class="reference internal" href="../performance/perf-best-practices.html">Best Practices</a></li>
<li class="toctree-l1"><a class="reference internal" href="../performance/perf-analysis.html">Performance Analysis</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Reference</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="troubleshooting.html">Troubleshooting</a></li>
<li class="toctree-l1"><a class="reference internal" href="support-matrix.html">Support Matrix</a></li>
<li class="toctree-l1"><a class="reference internal" href="precision.html">Numerical Precision</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">Memory Usage of TensorRT-LLM</a><ul>
<li class="toctree-l2"><a class="reference internal" href="#understand-inference-time-gpu-memory-usage">Understand inference time GPU memory usage</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#weights-size">1. Weights size</a></li>
<li class="toctree-l3"><a class="reference internal" href="#activation-size">2. Activation size</a></li>
<li class="toctree-l3"><a class="reference internal" href="#i-o-tensors">3. I/O tensors</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#runtime-and-decoder-buffers-except-kv-cache-tensor">3.1 Runtime and decoder buffers except KV cache tensor</a></li>
<li class="toctree-l4"><a class="reference internal" href="#kv-cache-tensor">3.2 KV cache tensor</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="#memory-pool">Memory pool</a></li>
<li class="toctree-l2"><a class="reference internal" href="#known-issues">Known Issues</a></li>
<li class="toctree-l2"><a class="reference internal" href="#faq">FAQ</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Blogs</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../blogs/H100vsA100.html">H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token</a></li>
<li class="toctree-l1"><a class="reference internal" href="../blogs/H200launch.html">H200 achieves nearly 12,000 tokens/sec on Llama2-13B with TensorRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../blogs/Falcon180B-H200.html">Falcon-180B on a single H200 GPU with INT4 AWQ, and 6.7x faster Llama-70B over A100</a></li>
<li class="toctree-l1"><a class="reference internal" href="../blogs/quantization-in-TRT-LLM.html">Speed up inference with SOTA quantization techniques in TRT-LLM</a></li>
<li class="toctree-l1"><a class="reference internal" href="../blogs/XQA-kernel.html">New XQA-kernel provides 2.4x more Llama-70B throughput within the same latency budget</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="../index.html">tensorrt_llm</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<section id="memory-usage-of-tensorrt-llm">
<span id="memory"></span><h1>Memory Usage of TensorRT-LLM<a class="headerlink" href="#memory-usage-of-tensorrt-llm" title="Link to this heading"></a></h1>
<p>This document summarizes the memory usage of TensorRT-LLM and addresses common issues and questions reported by users.</p>
<section id="understand-inference-time-gpu-memory-usage">
<h2>Understand inference time GPU memory usage<a class="headerlink" href="#understand-inference-time-gpu-memory-usage" title="Link to this heading"></a></h2>
<p>At inference time, there are three major contributors to GPU memory usage for a given TRT engine generated from a TensorRT-LLM model: weights, internal activation tensors, and I/O tensors. Among the I/O tensors, the KV cache tensor accounts for the largest share of the memory footprint.</p>
<section id="weights-size">
<h3>1. Weights size<a class="headerlink" href="#weights-size" title="Link to this heading"></a></h3>
<p>The weights size is fixed by the model size, the chosen precision of the weights, and the parallelization strategy.
Using a lower precision such as INT8 or FP8 reduces the weights size.
When tensor parallelism or pipeline parallelism is used, each rank stores only a portion of the weights.
For example, each rank typically holds just 1/8 of the model weights when using 8-way tensor parallelism or 8-stage pipeline parallelism.</p>
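<p>The arithmetic above can be sketched as a small Python helper. This is an illustrative estimate, not a TensorRT-LLM API: <code class="docutils literal notranslate"><span class="pre">weights_bytes_per_rank</span></code> and the precision table are hypothetical, and the helper assumes every weight is split evenly across the <code class="docutils literal notranslate"><span class="pre">tp_size</span> <span class="pre">*</span> <span class="pre">pp_size</span></code> ranks, whereas real engines may replicate some small tensors (such as normalization weights under tensor parallelism).</p>

```python
# Bytes per weight element for the precisions mentioned above.
BYTES_PER_PARAM = {"fp16": 2, "bf16": 2, "fp8": 1, "int8": 1}


def weights_bytes_per_rank(num_params, precision="fp16", tp_size=1, pp_size=1):
    """Estimate the weight memory held by one rank.

    Hypothetical helper: assumes every weight is split evenly across the
    tp_size * pp_size ranks, which ignores any replicated tensors.
    """
    return num_params * BYTES_PER_PARAM[precision] / (tp_size * pp_size)


if __name__ == "__main__":
    # A hypothetical 7B-parameter model stored in FP16.
    for tp in (1, 8):
        gib = weights_bytes_per_rank(7_000_000_000, "fp16", tp_size=tp) / 2**30
        print(f"tp={tp}: {gib:.2f} GiB per rank")
```

<p>With 8-way tensor parallelism the estimate drops to 1/8 of the single-GPU value, matching the example above; combining tensor and pipeline parallelism is modeled here by dividing by the product of both sizes, which is an assumption of this sketch.</p>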
</section>
<section id="activation-size">
<h3>2. Activation size<a class="headerlink" href="#activation-size" title="Link to this heading"></a></h3>
<p>TensorRT can optimize memory usage by reusing memory across different tensors based on liveness analysis and tensor sizes. To avoid out-of-memory errors at runtime and to reduce the runtime cost of switching optimization profiles and changing shapes, <strong>TensorRT pre-computes the memory requirement of the activation tensors at build time</strong>. The requirement is computed from the optimized TensorRT graph: the memory usage of each optimization profile is computed using that profile's maximum tensor shapes, and the memory requirement of the engine is the maximum over all profiles. Both external and internal factors can affect the activation size reported by TensorRT, such as the network structure, kernel fusion, and operation scheduling.</p>
<p>Once the TensorRT engine is built, the activation memory size of that engine <strong>cannot be changed</strong>, and can be queried by the API <code class="docutils literal notranslate"><span class="pre">trt.ICudaEngine.device_memory_size_v2</span></code>.</p>
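<p>A minimal sketch of such a query from Python is shown below. It assumes the TensorRT 10.x Python bindings, where <code class="docutils literal notranslate"><span class="pre">ICudaEngine</span></code> exposes <code class="docutils literal notranslate"><span class="pre">device_memory_size_v2</span></code>, and it assumes that importing <code class="docutils literal notranslate"><span class="pre">tensorrt_llm</span></code> registers the plugins a TensorRT-LLM engine needs before it can be deserialized. The helper names are hypothetical.</p>

```python
def bytes_to_mib(num_bytes):
    """Convert a byte count to MiB (2**20 bytes), a unit used in the runtime logs."""
    return num_bytes / 2**20


def engine_activation_bytes(engine_path):
    """Return the activation memory size that a built engine requires.

    Assumes TensorRT 10.x Python bindings, where ICudaEngine exposes
    device_memory_size_v2.
    """
    # Imported lazily so that bytes_to_mib works without TensorRT installed.
    import tensorrt as trt
    # Assumption: importing tensorrt_llm registers the TensorRT-LLM plugins
    # that the engine needs in order to be deserialized.
    import tensorrt_llm  # noqa: F401

    runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
    with open(engine_path, "rb") as f:
        engine = runtime.deserialize_cuda_engine(f.read())
    if engine is None:
        raise RuntimeError(f"failed to deserialize {engine_path}")
    return engine.device_memory_size_v2
```

<p>For example, <code class="docutils literal notranslate"><span class="pre">bytes_to_mib(engine_activation_bytes(path))</span></code>, where <code class="docutils literal notranslate"><span class="pre">path</span></code> points to an engine file produced by <code class="docutils literal notranslate"><span class="pre">trtllm-build</span></code>, gives the activation size in MiB.</p>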
<p>In practice, for a given model, precision, and parallelization strategy, one can tune the activation memory usage by adjusting the max batch size, max input length, max beam width, max number of tokens, the padding removal on/off flag, and the context FMHA on/off flag.
The following explains how these values affect memory:</p>
<ol class="arabic">
<li><p>Reduce the build-time max number of input tokens (<code class="docutils literal notranslate"><span class="pre">max_num_tokens</span></code>)</p>
<p>Most of the tensors inside a transformer network scale linearly with the number of input tokens, so the activation size will be close to <code class="docutils literal notranslate"><span class="pre">max</span> <span class="pre">number</span> <span class="pre">of</span> <span class="pre">input</span> <span class="pre">tokens</span> <span class="pre">*</span> <span class="pre">some</span> <span class="pre">constant</span> <span class="pre">factor</span></code>, where the constant factor depends on the network structure and TRT internal optimizations. The max number of input tokens is derived from build-time arguments: to affect the memory usage, one can change the parameters provided to the <code class="docutils literal notranslate"><span class="pre">prepare_inputs</span></code> function, such as <code class="docutils literal notranslate"><span class="pre">PretrainedModel.prepare_inputs</span></code>, or change the command-line options of the <code class="docutils literal notranslate"><span class="pre">trtllm-build</span></code> command used in the examples.</p>
<p>When using the <a class="reference internal" href="../advanced/gpt-attention.html#padded-and-packed-tensors"><span class="std std-ref">packed tensors</span></a> format and <code class="docutils literal notranslate"><span class="pre">max_num_tokens</span></code> is specified, reducing its value will also reduce activation memory size.</p>
<p>When using the <a class="reference internal" href="../advanced/gpt-attention.html#padded-and-packed-tensors"><span class="std std-ref">padded tensors</span></a> format, the max number of input tokens equals <code class="docutils literal notranslate"><span class="pre">max_batch_size*max_input_len</span></code>, so reducing <code class="docutils literal notranslate"><span class="pre">max_batch_size</span></code> and <code class="docutils literal notranslate"><span class="pre">max_input_len</span></code> can almost linearly reduce the activation memory size.</p>
<p>The packed tensors format is recommended, because it saves both memory and compute.</p>
<p>The beam width is folded into the batch size dimension when the tensor shape ranges are passed to TensorRT, so reducing <code class="docutils literal notranslate"><span class="pre">max_beam_width</span></code> can also reduce the memory usage.</p>
</li>
<li><p>Turn on context FMHA</p>
<p>When the GPT attention plugin is used, turning on the <code class="docutils literal notranslate"><span class="pre">context_fmha_type</span></code> option of the plugin reduces the memory footprint significantly. See the <a class="reference internal" href="../advanced/gpt-attention.html#context-phase"><span class="std std-ref">Context Phase</span></a> section for details. When <code class="docutils literal notranslate"><span class="pre">context_fmha_type</span></code> is set to disabled, the workspace size of the plugin depends quadratically on the sequence length.</p>
</li>
<li><p>Tensor parallelism and pipeline parallelism</p>
<p>TensorRT reuses memory between layers as much as possible. For a typical example, given <em>N</em> decoder blocks in one transformer network, TRT does not allocate <em>N</em> copies of the activation memory, one per block: the memory of the tensors in the first block can be released after that block executes and reused by later blocks, so only one block's worth of memory is needed.</p>
<p>When using tensor parallelism, some tensors are split into smaller chunks and each rank holds only one chunk, so the activation memory size of each rank is smaller than when executing the network on a single GPU. When using pipeline parallelism, each rank executes several decoder blocks and all tensors are full-size, so the activation memory size equals the memory size of one block. Thus, tensor parallelism normally has higher memory efficiency than pipeline parallelism when all other parameters are the same.</p>
</li>
</ol>
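<p>The rules of thumb in the list above can be sketched as follows. This is an illustration, not TensorRT's memory planner: the helper names and the <code class="docutils literal notranslate"><span class="pre">bytes_per_token_factor</span></code> input are hypothetical stand-ins for the model-dependent constant factor, and the quadratic term assumes that, without context FMHA, the plugin workspace is dominated by a <code class="docutils literal notranslate"><span class="pre">[batch,</span> <span class="pre">heads,</span> <span class="pre">seq_len,</span> <span class="pre">seq_len]</span></code> attention-score buffer.</p>

```python
def max_input_tokens(max_batch_size, max_input_len, max_num_tokens=None,
                     remove_input_padding=True):
    """Rule-of-thumb token bound that activation memory scales with.

    Hypothetical helper mirroring the text above; not TensorRT's planner.
    """
    if remove_input_padding and max_num_tokens is not None:
        # Packed tensors: the token dimension is capped by max_num_tokens.
        return max_num_tokens
    # Padded tensors: every sequence is padded to max_input_len.
    # (Packed format without max_num_tokens falls back here as an assumption.)
    return max_batch_size * max_input_len


def approx_activation_bytes(num_tokens, bytes_per_token_factor):
    # bytes_per_token_factor stands in for the model- and optimization-
    # dependent "constant factor"; it is an illustrative input only.
    return num_tokens * bytes_per_token_factor


def unfused_attention_workspace_bytes(batch_size, num_heads, seq_len,
                                      dtype_bytes=2):
    # Assumption: without context FMHA the dominant workspace term is a
    # [batch, heads, seq_len, seq_len] score buffer, hence seq_len ** 2.
    return batch_size * num_heads * seq_len * seq_len * dtype_bytes
```

<p>In the padded case, halving <code class="docutils literal notranslate"><span class="pre">max_input_len</span></code> halves the token bound, while doubling the sequence length quadruples the unfused attention workspace estimate.</p>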
</section>
<section id="i-o-tensors">
<h3>3. I/O tensors<a class="headerlink" href="#i-o-tensors" title="Link to this heading"></a></h3>
<section id="runtime-and-decoder-buffers-except-kv-cache-tensor">
<h4>3.1 Runtime and decoder buffers except KV cache tensor<a class="headerlink" href="#runtime-and-decoder-buffers-except-kv-cache-tensor" title="Link to this heading"></a></h4>
<section id="c-runtime">
<h5>C++ runtime<a class="headerlink" href="#c-runtime" title="Link to this heading"></a></h5>
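<p>Before KV cache blocks are allocated, the C++ runtime pre-allocates some GPU memory for the I/O tensors of the TensorRT engine and for the decoupled dynamic decoder. This memory is sized according to the runtime <code class="docutils literal notranslate"><span class="pre">max_batch_size</span></code> and <code class="docutils literal notranslate"><span class="pre">max_seq_len</span></code>, so that out-of-memory (OOM) errors are avoided when that many requests are actually scheduled.</p>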
</section>
</section>
<section id="kv-cache-tensor">
<h4>3.2 KV cache tensor<a class="headerlink" href="#kv-cache-tensor" title="Link to this heading"></a></h4>
<section id="id1">
<h5>C++ runtime<a class="headerlink" href="#id1" title="Link to this heading"></a></h5>
<ul>
<li><p>When paged KV cache is enabled</p>
<p>TensorRT-LLM runtime pre-allocates KV cache tensors during initialization for a configured number of blocks and distributes them at runtime.</p>
<p>KV cache tensors are allocated based on the <code class="docutils literal notranslate"><span class="pre">KVCacheConfig</span></code> object when creating the <code class="docutils literal notranslate"><span class="pre">Executor</span></code>. If neither <code class="docutils literal notranslate"><span class="pre">maxTokens</span></code> nor <code class="docutils literal notranslate"><span class="pre">freeGpuMemoryFraction</span></code> is specified, the KV cache by default allocates 90% of the remaining free GPU memory. When only one of <code class="docutils literal notranslate"><span class="pre">maxTokens</span></code> or <code class="docutils literal notranslate"><span class="pre">freeGpuMemoryFraction</span></code> is specified, that value is used to compute the KV cache memory size. If both are specified, <code class="docutils literal notranslate"><span class="pre">freeGpuMemoryFraction</span></code> is first used to compute a number of tokens for the KV cache, and the minimum of this computed number of tokens and <code class="docutils literal notranslate"><span class="pre">maxTokens</span></code> is then used.</p>
<p>With in-flight batching, the scheduler can automatically schedule requests as long as enough KV cache space is available (the exact behavior depends on the scheduler policy).</p>
<p>If paged KV cache is used in <code class="docutils literal notranslate"><span class="pre">GptSession</span></code> (already deprecated) without in-flight batching and the paged KV cache is not large enough for the whole batch, TensorRT-LLM may report OOM errors with the message “Can't allocate new blocks. No free blocks left”.</p>
</li>
<li><p>When paged KV cache is disabled (Not recommended and only allowed for deprecated <code class="docutils literal notranslate"><span class="pre">GptSession</span></code>)</p>
<p>The C++ runtime allocates the KV cache tensors for each layer with shape <code class="docutils literal notranslate"><span class="pre">[batch</span> <span class="pre">size,</span> <span class="pre">2,</span> <span class="pre">heads,</span> <span class="pre">max</span> <span class="pre">seq</span> <span class="pre">length,</span> <span class="pre">hidden</span> <span class="pre">dimension</span> <span class="pre">per</span> <span class="pre">head]</span></code>, where <code class="docutils literal notranslate"><span class="pre">max</span> <span class="pre">seq</span> <span class="pre">length</span></code> is specified by <code class="docutils literal notranslate"><span class="pre">GptSession::Config::maxSequenceLength</span></code> when creating the <code class="docutils literal notranslate"><span class="pre">GptSession</span></code>.</p>
</li>
</ul>
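<p>The sizing rules above can be sketched in Python as follows. This mirrors the documented behavior rather than the C++ implementation: all function names are hypothetical, <code class="docutils literal notranslate"><span class="pre">kv_bytes_per_token</span></code> assumes one K and one V vector per layer and per KV head (for multi-query or group-query attention, the "heads" in the shape above are assumed to be KV heads), and <code class="docutils literal notranslate"><span class="pre">tokens_per_block=64</span></code> is simply the value from the example log in the FAQ.</p>

```python
# Default fraction of free GPU memory used for the KV cache, per the text above.
DEFAULT_FREE_GPU_MEMORY_FRACTION = 0.9


def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_bytes):
    # One K and one V vector per layer and per KV head.
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def paged_kv_cache_tokens(free_gpu_bytes, bytes_per_token,
                          max_tokens=None, free_gpu_memory_fraction=None):
    """Mirror the documented maxTokens / freeGpuMemoryFraction rules."""
    if max_tokens is not None and free_gpu_memory_fraction is None:
        # Only maxTokens given: it determines the size.
        return max_tokens
    fraction = (DEFAULT_FREE_GPU_MEMORY_FRACTION
                if free_gpu_memory_fraction is None
                else free_gpu_memory_fraction)
    tokens_from_fraction = int(free_gpu_bytes * fraction) // bytes_per_token
    if max_tokens is None:
        # Only the fraction given, or neither (default fraction).
        return tokens_from_fraction
    # Both given: take the smaller of the two limits.
    return min(tokens_from_fraction, max_tokens)


def num_kv_blocks(num_tokens, tokens_per_block=64):
    # Round down to whole blocks (an assumption of this sketch).
    return num_tokens // tokens_per_block


def contiguous_kv_cache_bytes(batch_size, num_layers, num_kv_heads,
                              max_seq_len, head_dim, dtype_bytes):
    # Non-paged GptSession layout: one [batch, 2, heads, max_seq_len, head_dim]
    # tensor per layer.
    per_layer = batch_size * 2 * num_kv_heads * max_seq_len * head_dim * dtype_bytes
    return num_layers * per_layer
```

<p>As a consistency check on the example log in the FAQ, 4060 blocks of 64 tokens give 4060 × 64 = 259840 tokens, the count in its last line, and the 63.44 GiB allocation is close to 0.9 × 70.48 GiB ≈ 63.43 GiB, which appears consistent with the default 90% fraction.</p>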
</section>
<section id="python-runtime-not-recommended-to-be-used">
<h5>Python runtime (Not recommended to be used)<a class="headerlink" href="#python-runtime-not-recommended-to-be-used" title="Link to this heading"></a></h5>
<p>The Python runtime allocates KV cache tensors based on the parameters of the <code class="docutils literal notranslate"><span class="pre">GenerationSession.setup</span></code> function; the KV cache size depends linearly on <code class="docutils literal notranslate"><span class="pre">batch_size</span></code> and <code class="docutils literal notranslate"><span class="pre">max_context_length+max_new_tokens</span></code>. <strong>Note: This may change in the future, as the Python bindings of the C++ runtime may replace the current Python runtime. The Python bindings of the C++ runtime behave like the C++ runtime.</strong></p>
</section>
</section>
</section>
</section>
<section id="memory-pool">
<h2>Memory pool<a class="headerlink" href="#memory-pool" title="Link to this heading"></a></h2>
<p>The TensorRT-LLM C++ runtime uses a stream-ordered memory allocator to allocate and free buffers (see <a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/tensorrt_llm/runtime/bufferManager.cpp">BufferManager::initMemoryPool</a>), which uses the default memory pool managed by the CUDA driver. When a <code class="docutils literal notranslate"><span class="pre">GptSession</span></code> object is destroyed, its memory is returned to the memory pool and can be reused by the next <code class="docutils literal notranslate"><span class="pre">GptSession</span></code> instance. Memory is released from the pool if it is required for other memory allocations.</p>
<p>However, <code class="docutils literal notranslate"><span class="pre">nvidia-smi</span></code> may still show high memory occupation after memory is returned to the CUDA driver's memory pool. This is intended behavior and should not be a concern. The amount of reserved and free memory in the pool can be inspected with <a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/tensorrt_llm/runtime/bufferManager.cpp">BufferManager::memoryPoolReserved()</a> and <a class="reference external" href="https://github.com/NVIDIA/TensorRT-LLM/tree/rel/cpp/tensorrt_llm/runtime/bufferManager.cpp">BufferManager::memoryPoolFree()</a>, respectively.</p>
</section>
<section id="known-issues">
<h2>Known Issues<a class="headerlink" href="#known-issues" title="Link to this heading"></a></h2>
<p>When FP8 GEMM is used, the activation memory might be larger than the theoretically optimal memory size; this will be improved in a future release.</p>
</section>
<section id="faq">
<h2>FAQ<a class="headerlink" href="#faq" title="Link to this heading"></a></h2>
<ol class="arabic">
<li><p>How to debug the memory usage of TensorRT-LLM?</p>
<p>When the <code class="docutils literal notranslate"><span class="pre">info</span></code> logging level is used, TensorRT and TensorRT-LLM print detailed memory usage messages. Here is part of an example runtime log at the <code class="docutils literal notranslate"><span class="pre">info</span></code> logging level:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="p">[</span><span class="n">TensorRT</span><span class="o">-</span><span class="n">LLM</span><span class="p">][</span><span class="n">INFO</span><span class="p">]</span> <span class="n">Loaded</span> <span class="n">engine</span> <span class="n">size</span><span class="p">:</span> <span class="mi">6695</span> <span class="n">MiB</span>
<span class="p">[</span><span class="n">TensorRT</span><span class="o">-</span><span class="n">LLM</span><span class="p">][</span><span class="n">INFO</span><span class="p">]</span> <span class="p">[</span><span class="n">MemUsageChange</span><span class="p">]</span> <span class="n">Allocated</span> <span class="mf">1134.01</span> <span class="n">MiB</span> <span class="k">for</span> <span class="n">execution</span> <span class="n">context</span> <span class="n">memory</span><span class="o">.</span>
<span class="p">[</span><span class="n">TensorRT</span><span class="o">-</span><span class="n">LLM</span><span class="p">][</span><span class="n">INFO</span><span class="p">]</span> <span class="p">[</span><span class="n">MS</span><span class="p">]</span> <span class="n">Running</span> <span class="n">engine</span> <span class="k">with</span> <span class="n">multi</span> <span class="n">stream</span> <span class="n">info</span>
<span class="p">[</span><span class="n">TensorRT</span><span class="o">-</span><span class="n">LLM</span><span class="p">][</span><span class="n">INFO</span><span class="p">]</span> <span class="p">[</span><span class="n">MS</span><span class="p">]</span> <span class="n">Number</span> <span class="n">of</span> <span class="n">aux</span> <span class="n">streams</span> <span class="ow">is</span> <span class="mi">1</span>
<span class="p">[</span><span class="n">TensorRT</span><span class="o">-</span><span class="n">LLM</span><span class="p">][</span><span class="n">INFO</span><span class="p">]</span> <span class="p">[</span><span class="n">MS</span><span class="p">]</span> <span class="n">Number</span> <span class="n">of</span> <span class="n">total</span> <span class="n">worker</span> <span class="n">streams</span> <span class="ow">is</span> <span class="mi">2</span>
<span class="p">[</span><span class="n">TensorRT</span><span class="o">-</span><span class="n">LLM</span><span class="p">][</span><span class="n">INFO</span><span class="p">]</span> <span class="p">[</span><span class="n">MS</span><span class="p">]</span> <span class="n">The</span> <span class="n">main</span> <span class="n">stream</span> <span class="n">provided</span> <span class="n">by</span> <span class="n">execute</span><span class="o">/</span><span class="n">enqueue</span> <span class="n">calls</span> <span class="ow">is</span> <span class="n">the</span> <span class="n">first</span> <span class="n">worker</span> <span class="n">stream</span>
<span class="p">[</span><span class="n">TensorRT</span><span class="o">-</span><span class="n">LLM</span><span class="p">][</span><span class="n">INFO</span><span class="p">]</span> <span class="p">[</span><span class="n">MemUsageChange</span><span class="p">]</span> <span class="n">TensorRT</span><span class="o">-</span><span class="n">managed</span> <span class="n">allocation</span> <span class="ow">in</span> <span class="n">IExecutionContext</span> <span class="n">creation</span><span class="p">:</span> <span class="n">CPU</span> <span class="o">+</span><span class="mi">0</span><span class="p">,</span> <span class="n">GPU</span> <span class="o">+</span><span class="mi">0</span><span class="p">,</span> <span class="n">now</span><span class="p">:</span> <span class="n">CPU</span> <span class="mi">0</span><span class="p">,</span> <span class="n">GPU</span> <span class="mi">6678</span> <span class="p">(</span><span class="n">MiB</span><span class="p">)</span>
<span class="p">[</span><span class="n">TensorRT</span><span class="o">-</span><span class="n">LLM</span><span class="p">][</span><span class="n">INFO</span><span class="p">]</span> <span class="p">[</span><span class="n">MemUsageChange</span><span class="p">]</span> <span class="n">Allocated</span> <span class="mf">43.29</span> <span class="n">MB</span> <span class="n">GPU</span> <span class="n">memory</span> <span class="k">for</span> <span class="n">runtime</span> <span class="n">buffers</span><span class="o">.</span>
<span class="p">[</span><span class="n">TensorRT</span><span class="o">-</span><span class="n">LLM</span><span class="p">][</span><span class="n">INFO</span><span class="p">]</span> <span class="p">[</span><span class="n">MemUsageChange</span><span class="p">]</span> <span class="n">Allocated</span> <span class="mf">180.30</span> <span class="n">MB</span> <span class="n">GPU</span> <span class="n">memory</span> <span class="k">for</span> <span class="n">decoder</span><span class="o">.</span>
<span class="p">[</span><span class="n">TensorRT</span><span class="o">-</span><span class="n">LLM</span><span class="p">][</span><span class="n">INFO</span><span class="p">]</span> <span class="n">Memory</span> <span class="n">usage</span> <span class="n">when</span> <span class="n">calculating</span> <span class="nb">max</span> <span class="n">tokens</span> <span class="ow">in</span> <span class="n">paged</span> <span class="n">kv</span> <span class="n">cache</span><span class="p">:</span> <span class="n">total</span><span class="p">:</span> <span class="mf">79.10</span> <span class="n">GiB</span><span class="p">,</span> <span class="n">available</span><span class="p">:</span> <span class="mf">70.48</span> <span class="n">GiB</span>
<span class="p">[</span><span class="n">TensorRT</span><span class="o">-</span><span class="n">LLM</span><span class="p">][</span><span class="n">INFO</span><span class="p">]</span> <span class="n">Number</span> <span class="n">of</span> <span class="n">blocks</span> <span class="ow">in</span> <span class="n">KV</span> <span class="n">cache</span> <span class="n">primary</span> <span class="n">pool</span><span class="p">:</span> <span class="mi">4060</span>
<span class="p">[</span><span class="n">TensorRT</span><span class="o">-</span><span class="n">LLM</span><span class="p">][</span><span class="n">INFO</span><span class="p">]</span> <span class="n">Number</span> <span class="n">of</span> <span class="n">blocks</span> <span class="ow">in</span> <span class="n">KV</span> <span class="n">cache</span> <span class="n">secondary</span> <span class="n">pool</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="n">onboard</span> <span class="n">blocks</span> <span class="n">to</span> <span class="n">primary</span> <span class="n">memory</span> <span class="n">before</span> <span class="n">reuse</span><span class="p">:</span> <span class="n">true</span>
<span class="p">[</span><span class="n">TensorRT</span><span class="o">-</span><span class="n">LLM</span><span class="p">][</span><span class="n">INFO</span><span class="p">]</span> <span class="n">Max</span> <span class="n">KV</span> <span class="n">cache</span> <span class="n">pages</span> <span class="n">per</span> <span class="n">sequence</span><span class="p">:</span> <span class="mi">32</span>
<span class="p">[</span><span class="n">TensorRT</span><span class="o">-</span><span class="n">LLM</span><span class="p">][</span><span class="n">INFO</span><span class="p">]</span> <span class="n">Number</span> <span class="n">of</span> <span class="n">tokens</span> <span class="n">per</span> <span class="n">block</span><span class="p">:</span> <span class="mf">64.</span>
<span class="p">[</span><span class="n">TensorRT</span><span class="o">-</span><span class="n">LLM</span><span class="p">][</span><span class="n">INFO</span><span class="p">]</span> <span class="p">[</span><span class="n">MemUsageChange</span><span class="p">]</span> <span class="n">Allocated</span> <span class="mf">63.44</span> <span class="n">GiB</span> <span class="k">for</span> <span class="nb">max</span> <span class="n">tokens</span> <span class="ow">in</span> <span class="n">paged</span> <span class="n">KV</span> <span class="n">cache</span> <span class="p">(</span><span class="mi">259840</span><span class="p">)</span><span class="o">.</span>
</pre></div>
</div>
<p>Several GPU memory allocations happen at runtime, and each one is logged with the <code class="docutils literal notranslate"><span class="pre">[MemUsageChange]</span></code> keyword.</p>
<p>The “Total Weights Memory” line reports the weights memory size, and the “Total Activation Memory” line reports the activation memory size.</p>
<p>The weights memory size is normally close to the TensorRT engine size, because for LLM networks most of the engine's content is weights.</p>
</li>
<li><p>Why is the memory size large even though a small batch size and sequence length are used at runtime?</p>
<p>As explained above, the activation memory size is computed from the maximum tensor shapes specified when the TensorRT engine is built, not from the shapes used at runtime. To reduce it, lower build-time parameters such as <code class="docutils literal notranslate"><span class="pre">max_num_tokens</span></code>; see <a class="reference internal" href="#activation-size"><span class="xref myst">Activation size</span></a> for details.</p>
</li>
<li><p>Why can the engine be built successfully, yet inference runs out of memory (OOM) at runtime?</p>
<p>At engine build time, TensorRT tunes kernel selection layer by layer, so it does not necessarily allocate all the memory required to run the entire engine. If the activation tensors needed to run a single layer are small while the I/O tensors (such as the KV cache) needed to run the engine are large, the build can succeed because it never has to allocate the large I/O tensors, but the runtime may then fail with an OOM error when it tries to allocate them.</p>
<p>TensorRT-LLM provides a <code class="docutils literal notranslate"><span class="pre">check_gpt_mem_usage</span></code> utility function that estimates the upper bound of memory usage for a given engine and its related settings (batch size, input/output sequence lengths, and so on). When this upper bound exceeds the GPU's physical memory, it prints warning messages.</p>
</li>
<li><p>For pipeline parallelism, is the build-time max batch size the limit on the micro batch size?</p>
<p>Yes. In pipeline-parallel mode, the TensorRT-LLM runtime splits the batch of requests into micro batches and enqueues these micro batches into the TensorRT engine sequentially.</p>
<p>The build-time <code class="docutils literal notranslate"><span class="pre">max_batch_size</span></code> means that the batch size of a single engine enqueue call must not exceed it. The total batch size before splitting into micro batches can be larger than the build-time <code class="docutils literal notranslate"><span class="pre">max_batch_size</span></code>.</p>
<p>For example, suppose you use 4-stage pipeline parallelism and intend to run the engine with a micro batch size of 2 and 16 micro batches (a total batch size of 32) in one <code class="docutils literal notranslate"><span class="pre">generate</span></code> call.</p>
<p>In that case, you can set the build-time <code class="docutils literal notranslate"><span class="pre">max_batch_size</span></code> to 2 instead of 32. Setting the build-time <code class="docutils literal notranslate"><span class="pre">max_batch_size</span></code> to 32 would occupy roughly 16x more activation memory.</p>
</li>
</ol>
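<p>The paged KV cache lines in the example runtime log of the first question above are related by simple arithmetic. The Python sketch below re-derives them from the printed values; the helper name <code class="docutils literal notranslate"><span class="pre">kv_cache_summary</span></code> is hypothetical and not part of TensorRT-LLM. Because the log rounds GiB values to two decimals, the per-block and per-token sizes are approximate. The ratio of the KV cache allocation to the “available” memory comes out at about 0.9; that this matches a default free-GPU-memory fraction of 0.9 for the KV cache is an assumption, not something the log states.</p>

```python
GIB = 1024 ** 3
MIB = 1024 ** 2
KIB = 1024


def kv_cache_summary(available_gib, kv_cache_gib, num_blocks,
                     tokens_per_block, max_pages_per_seq):
    """Derive paged-KV-cache quantities from values printed in the log.

    Hypothetical helper for illustration only.
    """
    bytes_per_block = kv_cache_gib * GIB / num_blocks
    return {
        # Total token capacity of the primary pool (the number in parentheses).
        "max_tokens": num_blocks * tokens_per_block,
        # Longest sequence a single request can hold in the cache.
        "max_tokens_per_seq": max_pages_per_seq * tokens_per_block,
        # Approximate size of one KV cache block.
        "mib_per_block": bytes_per_block / MIB,
        # Approximate KV cache footprint of one token.
        "kib_per_token": bytes_per_block / tokens_per_block / KIB,
        # Share of the "available" memory that went to the KV cache.
        "fraction_of_available": kv_cache_gib / available_gib,
    }


# Values from the log: available 70.48 GiB, 63.44 GiB allocated,
# 4060 blocks, 64 tokens per block, 32 pages per sequence.
summary = kv_cache_summary(70.48, 63.44, 4060, 64, 32)
for name, value in summary.items():
    if isinstance(value, float):
        print(f"{name}: {value:.2f}")
    else:
        print(f"{name}: {value}")
```

<p>With these inputs the sketch reproduces the 259840 tokens shown in the log (4060 blocks × 64 tokens), a per-sequence limit of 2048 tokens (32 pages × 64 tokens), roughly 16 MiB per block (about 256 KiB per token), and a ratio of about 0.90.</p>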
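<p>The out-of-memory question above turns on the gap between what the builder allocates and what the runtime may need. The following is a minimal sketch of an upper-bound check in the spirit of <code class="docutils literal notranslate"><span class="pre">check_gpt_mem_usage</span></code>, assuming total memory ≈ weights + activations + KV cache, with a worst-case KV cache of 2 × batch × sequence length × layers × KV heads × head dimension × bytes per element (one K and one V tensor per layer). The functions <code class="docutils literal notranslate"><span class="pre">kv_cache_bytes</span></code> and <code class="docutils literal notranslate"><span class="pre">check_memory_upper_bound</span></code> are hypothetical; they do not reproduce the signature or the exact accounting of <code class="docutils literal notranslate"><span class="pre">check_gpt_mem_usage</span></code>. The model shape, weight and activation sizes, and GPU capacity are illustrative assumptions.</p>

```python
import warnings

GIB = 1024 ** 3


def kv_cache_bytes(batch_size, max_seq_len, num_layers, num_kv_heads,
                   head_dim, bytes_per_elem):
    """Worst-case KV cache size: one K and one V tensor per layer."""
    return (2 * batch_size * max_seq_len * num_layers * num_kv_heads
            * head_dim * bytes_per_elem)


def check_memory_upper_bound(weights_bytes, activation_bytes, kv_bytes,
                             gpu_bytes):
    """Sum the major consumers and warn if they exceed GPU memory.

    Hypothetical helper: it ignores runtime buffers, decoder buffers and
    library workspaces, so a real requirement would be somewhat higher.
    """
    total = weights_bytes + activation_bytes + kv_bytes
    fits = total <= gpu_bytes
    if not fits:
        warnings.warn(
            f"Estimated upper bound {total / GIB:.1f} GiB exceeds "
            f"GPU memory {gpu_bytes / GIB:.1f} GiB; expect OOM at runtime.")
    return total, fits


# Illustrative 7B-class FP16 model (assumed shape), batch 64, 4096 tokens.
kv = kv_cache_bytes(batch_size=64, max_seq_len=4096, num_layers=32,
                    num_kv_heads=32, head_dim=128, bytes_per_elem=2)
total, fits = check_memory_upper_bound(
    weights_bytes=13 * GIB,     # assumed weights size
    activation_bytes=2 * GIB,   # assumed activation size
    kv_bytes=kv,
    gpu_bytes=80 * GIB)         # assumed 80 GiB GPU
print(f"KV cache: {kv / GIB:.0f} GiB, total: {total / GIB:.0f} GiB, fits: {fits}")
```

<p>Here the worst-case KV cache alone is 128 GiB, so the check warns even though nothing about it would stop an engine build. With the paged KV cache, the runtime instead sizes the pool from the memory that is actually available, as the “available” and “Allocated … for max tokens in paged KV cache” lines of the example log show.</p>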
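<p>To make the pipeline-parallelism example above concrete, the sketch below splits 32 requests into micro batches no larger than a build-time <code class="docutils literal notranslate"><span class="pre">max_batch_size</span></code> of 2 and then compares activation footprints, under the simplifying assumption that activation memory grows roughly linearly with the build-time <code class="docutils literal notranslate"><span class="pre">max_batch_size</span></code>. The function <code class="docutils literal notranslate"><span class="pre">split_into_micro_batches</span></code> is a hypothetical illustration, not the TensorRT-LLM runtime's scheduler.</p>

```python
def split_into_micro_batches(requests, max_batch_size):
    """Chunk requests so that no engine enqueue exceeds max_batch_size."""
    if max_batch_size < 1:
        raise ValueError("max_batch_size must be at least 1")
    return [requests[i:i + max_batch_size]
            for i in range(0, len(requests), max_batch_size)]


requests = [f"request-{i}" for i in range(32)]   # total batch size 32
build_max_batch_size = 2                         # build-time max_batch_size

micro_batches = split_into_micro_batches(requests, build_max_batch_size)
# Every enqueue call stays within the build-time limit.
assert all(len(mb) <= build_max_batch_size for mb in micro_batches)
print(f"{len(micro_batches)} micro batches of size {build_max_batch_size}")

# Assuming activation memory scales roughly linearly with the build-time
# max_batch_size, building for 32 instead of 2 costs about this factor:
print(f"activation memory ratio: {len(requests) / build_max_batch_size:.0f}x")
```

<p>The engine only ever sees micro batches of size 2, which is why building it for the total batch size of 32 would reserve activation memory it never uses.</p>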
</section>
</section>
</div>
</div>
<footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
<a href="precision.html" class="btn btn-neutral float-left" title="Numerical Precision" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
<a href="../blogs/H100vsA100.html" class="btn btn-neutral float-right" title="H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
</div>
<hr/>
<div role="contentinfo">
<div class="footer">
<p>
Copyright © 2024 NVIDIA Corporation
</p>
<p>
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/privacy-policy/" target="_blank" rel="noopener"
data-cms-ai="0">Privacy Policy</a> |
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/privacy-center/" target="_blank" rel="noopener"
data-cms-ai="0">Manage My Privacy</a> |
<a class="Link" href="https://www.nvidia.com/en-us/preferences/start/" target="_blank" rel="noopener"
data-cms-ai="0">Do Not Sell or Share My Data</a> |
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/terms-of-service/" target="_blank"
rel="noopener" data-cms-ai="0">Terms of Service</a> |
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/accessibility/" target="_blank" rel="noopener"
data-cms-ai="0">Accessibility</a> |
<a class="Link" href="https://www.nvidia.com/en-us/about-nvidia/company-policies/" target="_blank"
rel="noopener" data-cms-ai="0">Corporate Policies</a> |
<a class="Link" href="https://www.nvidia.com/en-us/product-security/" target="_blank" rel="noopener"
data-cms-ai="0">Product Security</a> |
<a class="Link" href="https://www.nvidia.com/en-us/contact/" target="_blank" rel="noopener"
data-cms-ai="0">Contact</a>
</p>
</div>
</div>
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>