TensorRT-LLMs/tensorrt_llm/bench/build/tuning.py

import math
from typing import Tuple

import torch

from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.llmapi.llm_utils import QuantConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.quantization.mode import QuantAlgo
from tensorrt_llm.bench.build.dataclasses import ModelConfig, NemotronHybridConfig

from .utils import get_device_memory

BYTES_PER_ELEM = {
    QuantAlgo.NO_QUANT: 2.0,
    QuantAlgo.FP8: 1.0,
    QuantAlgo.FP8_BLOCK_SCALES: 1.0,
    QuantAlgo.NVFP4: 0.5,
}
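
# Illustrative sizing note (not from the original source): with FP8 weights
# (1.0 byte per element in BYTES_PER_ELEM above), an 8B-parameter model takes
# roughly 8e9 * 1.0 / 1024**3 ≈ 7.45 GiB; this is the `engine_size` term
# computed in calc_engine_setting() below.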


def calc_engine_setting(
    model_config: ModelConfig,
    quant_config: QuantConfig,
    tp_size: int,
    pp_size: int,
    target_input_len: int,
    target_output_len: int,
    kv_cache_gpu_mem_fraction: float = 0.95,
) -> Tuple[int, int]:
    """ Calculate the engine build settings (max batch size and max num tokens)
    for a specific model + parallelism mapping + dataset configuration.

    trtllm-bench sets a slightly optimistic upper bound for max batch size and
    max num tokens to avoid over-allocating memory for activation, runtime, and
    decoder buffers. At runtime, TRT-LLM relies on its runtime tuning features
    to adjust the effective max batch size to the incoming traffic.

    Args:
        model_config (ModelConfig): Model-specific configuration.
        quant_config (QuantConfig): Quantization specification.
        tp_size (int): Number of tensor parallel shards.
        pp_size (int): Number of pipeline parallel stages.
        target_input_len (int): Target input length to compile the engine for.
        target_output_len (int): Target output length to compile the engine for.
        kv_cache_gpu_mem_fraction (float): Fraction of free memory to allocate
            for the KV cache.

    Raises:
        RuntimeError: When the number of GPUs or the amount of KV cache memory
            is insufficient to support the model.

    Returns:
        Tuple[int, int]: Tuple containing the engine build settings
            (max_batch_size, max_num_tokens).
    """
    byte_per_elem = BYTES_PER_ELEM.get(quant_config.quant_algo, 2)
    byte_per_kv_elem = BYTES_PER_ELEM.get(quant_config.kv_cache_quant_algo, 2)

    # Each GPU in TP group has at least 1 kv head.
    adjusted_num_kv_heads = max(tp_size, model_config.num_key_value_heads)
    logger.info(
        f"Number of attention layers: {model_config.num_attention_layers}")
    gb_per_token = 2 * model_config.num_attention_layers * adjusted_num_kv_heads \
        * model_config.head_size * byte_per_kv_elem / (1024 ** 3)
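    # Formula note: the leading 2 accounts for the K and V tensors. Illustrative
    # example (assumed Llama-3-8B-like shapes, not taken from this file):
    # 32 attention layers, 8 KV heads, head size 128, tp_size 1, FP16 KV cache
    # -> 2 * 32 * 8 * 128 * 2 bytes = 128 KiB of KV cache per token.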

    # Number of GPUs used for this run.
    n_gpus = tp_size * pp_size
    # Total engine (weights) size.
    engine_size = model_config.param_count * byte_per_elem / (1024**3)
    total_gpu_memory = get_device_memory() * n_gpus
    # Available memory to allocate KV cache.
    available_memory = total_gpu_memory - engine_size
    logger.info(f"Estimated engine size: {engine_size:.2f} GB")
    logger.info("Estimated total available memory for KV cache: "
                f"{available_memory:.2f} GB")

    # Calculate max requests in KV cache based on target ISL and OSL.
    target_seq_len = target_input_len + target_output_len
    cache_memory = available_memory * model_config.cache_memory_fraction(
        kv_cache_gpu_mem_fraction)
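    # Note: total_gpu_memory, available_memory, and cache_memory are aggregated
    # across all n_gpus devices; per-GPU figures are recovered later by dividing
    # by n_gpus (see the per-GPU cache warning below).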

    # Element size for the extra (Mamba/SSM) model cache: defaults to the
    # unquantized 2 bytes, bumped to 4 bytes when an FP32 SSM cache is
    # explicitly requested for hybrid models.
    bytes_per_elem = BYTES_PER_ELEM.get(QuantAlgo.NO_QUANT)
    if isinstance(model_config, NemotronHybridConfig):
        mamba_ssm_cache_dtype = model_config.mamba_ssm_cache_dtype
        if mamba_ssm_cache_dtype != "auto":
            if str_dtype_to_torch(mamba_ssm_cache_dtype) == torch.float32:
                bytes_per_elem = 4.0

    gb_per_extra_cache = model_config.extra_model_cache_in_gb(
        bytes_per_elem, target_seq_len)
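    # Each in-flight request costs (gb_per_token * target_seq_len) GiB of KV
    # cache plus, for hybrid models, a per-request extra (Mamba/SSM) cache;
    # dividing the cache budget by that per-request cost bounds how many
    # requests can be resident at once.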
    kv_cache_max_requests = cache_memory / (gb_per_token * target_seq_len +
                                            gb_per_extra_cache)
    extra_cache_memory = gb_per_extra_cache * kv_cache_max_requests
    kv_cache_memory = cache_memory - extra_cache_memory
    kv_cache_max_tokens = kv_cache_memory / gb_per_token
    logger.info(f"Estimated total cache memory: {cache_memory:.2f} GB. "
                f"KV cache: {kv_cache_memory:.2f} GB, "
                f"Extra cache: {extra_cache_memory:.2f} GB")
    logger.info(f"Estimated kv cache max tokens: {kv_cache_max_tokens:.2f}")
    logger.info("Estimated max number of requests in KV cache memory: "
                f"{kv_cache_max_requests:.2f}")

    # Fine-tune the max batch size and num token setting for performance.
    # For mamba-attn hybrid models, we disable optimistic tuning because the
    # mamba cache leaves less memory for the KV cache.
    max_batch_size, max_num_tokens = finetune_setting(
        kv_cache_max_requests,
        target_input_len,
        target_output_len,
        pp_size,
        disable_optimistic_tuning=isinstance(model_config,
                                             NemotronHybridConfig))

    # Functional and performance sanity checks.
    if total_gpu_memory < engine_size:
        raise RuntimeError(
            f"The model requires at least: {engine_size:.2f} GB, "
            f"the total GPU memory of {total_gpu_memory:.2f} GB is insufficient.\n"
            "----------------------------------------------------------\n"
            "Estimation based on the following:\n"
            "----------------------------------------------------------\n"
            f"Bytes per Element: {byte_per_elem}\n"
            f"Bytes per KV Element: {byte_per_kv_elem}\n"
            f"Number of GPUs: {n_gpus}\n"
            f"Model Number of KV Heads: {model_config.num_key_value_heads}\n"
            f"Adjusted Number of KV Heads: {adjusted_num_kv_heads}\n"
            f"Head Size: {model_config.head_size}\n"
            f"Number of Hidden Layers: {model_config.num_hidden_layers}\n"
            f"Number of Tensor Parallel Shards: {tp_size}\n"
            f"Number of Pipeline Parallel Stages: {pp_size}\n"
            f"KV Cache GPU Memory Fraction: {kv_cache_gpu_mem_fraction}\n"
            "----------------------------------------------------------\n")
    if kv_cache_max_requests < 1:
        raise RuntimeError("The amount of KV cache memory is insufficient to "
                           "run this model. Please try with more GPUs.")
    if cache_memory / n_gpus < 10.0:
        logger.warning(
            "The KV cache memory per GPU is less than 10 GB. "
            "Performance may be undesirable. Please consider using a different "
            "mapping or more GPUs.")
    if kv_cache_max_requests < 32:
        logger.warning(
            "The maximum number of requests in the KV cache is too "
            "small. Performance may be undesirable. Please consider using more "
            "GPUs or a different mapping to process more concurrent requests.")

    return max_batch_size, max_num_tokens


def finetune_setting(
    kv_cache_max_requests: float,
    input_len: int,
    output_len: int,
    pp_size: int,
    disable_optimistic_tuning: bool = False,
) -> Tuple[int, int]:
    """ Calculate and fine-tune the engine build settings (max batch size and
    max num tokens). Both max batch size and max num tokens are fine-tuned
    to be slightly optimistic.

    Args:
        kv_cache_max_requests (float): Max number of requests that can fit in
            the available KV cache memory.
        input_len (int): Input sequence length to compile the engine for.
        output_len (int): Output sequence length to compile the engine for.
        pp_size (int): Number of pipeline parallel stages.
        disable_optimistic_tuning (bool): Whether to disable optimistic tuning.

    Returns:
        Tuple[int, int]: Tuple containing the fine-tuned engine build settings
            (max_batch_size, max_num_tokens).
    """
    # Cap total batch size to avoid decoder buffer size becoming too large.
    raw_bs = min(kv_cache_max_requests, 4096) / pp_size
    # Cap num tokens to avoid TRT activation buffer becoming too large.
    raw_token = min(raw_bs * (1 + input_len / output_len), 32768)
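    # Reading of the formula above: per scheduling step each active request
    # contributes roughly one generation token, plus input_len context tokens
    # amortized over output_len generation steps, hence the
    # (1 + input_len / output_len) factor.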

    # Fine-tune the max batch size.
    if disable_optimistic_tuning:
        max_bs = 2 * math.floor(raw_bs / 2)
    else:
        # Set min BS to be 64.
        if raw_bs < 256:
            max_bs = max(64, 32 * math.ceil(raw_bs / 32))
        elif raw_bs < 1024:
            max_bs = 128 * math.ceil(raw_bs / 128)
        else:
            max_bs = 256 * math.ceil(raw_bs / 256)

    # Fine-tune the max num tokens.
    # Set min to 2048 to ensure Ctx/Gen overlap efficiency.
    if raw_token < 4096:
        max_token = max(2048, 256 * math.ceil(raw_token / 256))
    elif raw_token < 8192:
        max_token = 512 * math.ceil(raw_token / 512)
    else:
        max_token = 1024 * math.ceil(raw_token / 1024)
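    # Worked example (illustrative numbers, not from the original source):
    # kv_cache_max_requests = 500, pp_size = 1, input_len = output_len = 128
    # -> raw_bs = 500, raw_token = 1000
    # -> max_bs = 128 * ceil(500 / 128) = 512
    # -> max_token = max(2048, 256 * ceil(1000 / 256)) = 2048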
logger.debug(f"Estimated max batch size (before fine-tune): "
f"{kv_cache_max_requests / pp_size:.2f}")
logger.debug(
f"Estimated max num tokens (before fine-tune): "
f"{kv_cache_max_requests / pp_size * (1 + input_len / output_len) :.2f}"
)
logger.info(f"Estimated max batch size (after fine-tune): {max_bs}")
logger.info(f"Estimated max num tokens (after fine-tune): {max_token}")
return max_bs, max_token
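

# Usage sketch (illustrative only; the ModelConfig construction is handled
# elsewhere by trtllm-bench's model loading path, so a `model_config` object is
# assumed to already exist here):
#
#     quant_config = QuantConfig(quant_algo=QuantAlgo.FP8)
#     max_batch_size, max_num_tokens = calc_engine_setting(
#         model_config=model_config,   # produced by trtllm-bench's model loader
#         quant_config=quant_config,
#         tp_size=1,
#         pp_size=1,
#         target_input_len=1024,
#         target_output_len=1024,
#     )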