import math
from typing import Tuple

import torch

from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.bench.build.dataclasses import ModelConfig, NemotronHybridConfig
from tensorrt_llm.llmapi.llm_utils import QuantConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.quantization.mode import QuantAlgo

from .utils import get_device_memory

BYTES_PER_ELEM = {
    QuantAlgo.NO_QUANT: 2.0,
    QuantAlgo.FP8: 1.0,
    QuantAlgo.FP8_BLOCK_SCALES: 1.0,
    QuantAlgo.NVFP4: 0.5,
}


def calc_engine_setting(
    model_config: ModelConfig,
    quant_config: QuantConfig,
    tp_size: int,
    pp_size: int,
    target_input_len: int,
    target_output_len: int,
    kv_cache_gpu_mem_fraction: float = 0.95,
) -> Tuple[int, int]:
    """ Calculate the engine build settings (max batch size and max num tokens)
    for a specific model + parallelism mapping + dataset configuration.

    trtllm-bench sets a slightly optimistic upper bound for max batch size
    and max num tokens to avoid over-allocation of memory in activation,
    runtime, and decoder buffers. At runtime, TRT-LLM relies on its runtime
    tuning features to adjust the runtime max batch size according to
    incoming traffic.

    Args:
        model_config (ModelConfig): Model-specific configuration.
        quant_config (QuantConfig): Quantization specification.
        tp_size (int): Number of tensor parallel shards.
        pp_size (int): Number of pipeline parallel stages.
        target_input_len (int): Target input length to compile the engine for.
        target_output_len (int): Target output length to compile the engine for.
        kv_cache_gpu_mem_fraction (float): Fraction of free memory to allocate
            for the KV cache.

    Raises:
        RuntimeError: When the number of GPUs or the amount of KV cache memory
            cannot support the model.

    Returns:
        Tuple[int, int]: Tuple containing the engine build configuration as
            (max_batch_size, max_num_tokens).
    """
    byte_per_elem = BYTES_PER_ELEM.get(quant_config.quant_algo, 2)
    byte_per_kv_elem = BYTES_PER_ELEM.get(quant_config.kv_cache_quant_algo, 2)

    # Each GPU in the TP group has at least 1 KV head.
    adjusted_num_kv_heads = max(tp_size, model_config.num_key_value_heads)

    logger.info(
        f"Number of attention layers: {model_config.num_attention_layers}")

    gb_per_token = 2 * model_config.num_attention_layers * adjusted_num_kv_heads \
        * model_config.head_size * byte_per_kv_elem / (1024 ** 3)
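
    # Illustrative arithmetic for gb_per_token (assumed values, not tied to a
    # specific checkpoint): 32 attention layers, 8 KV heads, head_size 128,
    # FP16 KV cache (2 bytes/elem), and tp_size 1 give
    #   2 * 32 * 8 * 128 * 2 / 1024**3 ≈ 1.22e-4 GB (about 128 KiB) per token.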

    # Number of GPUs used for this run.
    n_gpus = tp_size * pp_size
    # Total engine size.
    engine_size = model_config.param_count * byte_per_elem / (1024**3)
    total_gpu_memory = get_device_memory() * n_gpus
    # Available memory to allocate KV cache.
    available_memory = total_gpu_memory - engine_size
    logger.info(f"Estimated engine size: {engine_size:.2f} GB")
    logger.info("Estimated total available memory for KV cache: "
                f"{available_memory:.2f} GB")
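
    # Illustrative sizing (assumed numbers): an 8B-parameter model with FP8
    # weights (1.0 byte/elem) needs roughly 8e9 / 1024**3 ≈ 7.45 GB for the
    # weights, so a single 80 GB GPU would leave about 72.5 GB available.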

    # Calculate max requests in KV cache based on target ISL and OSL.
    target_seq_len = target_input_len + target_output_len
    cache_memory = available_memory * model_config.cache_memory_fraction(
        kv_cache_gpu_mem_fraction)

    bytes_per_elem = BYTES_PER_ELEM.get(QuantAlgo.NO_QUANT)
    if isinstance(model_config, NemotronHybridConfig):
        mamba_ssm_cache_dtype = model_config.mamba_ssm_cache_dtype
        if mamba_ssm_cache_dtype != "auto":
            if str_dtype_to_torch(mamba_ssm_cache_dtype) == torch.float32:
                bytes_per_elem = 4.0

    gb_per_extra_cache = model_config.extra_model_cache_in_gb(
        bytes_per_elem, target_seq_len)
    kv_cache_max_requests = cache_memory / (gb_per_token * target_seq_len +
                                            gb_per_extra_cache)
    extra_cache_memory = gb_per_extra_cache * kv_cache_max_requests
    kv_cache_memory = cache_memory - extra_cache_memory
    kv_cache_max_tokens = kv_cache_memory / gb_per_token
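
    # Illustrative continuation (assumed numbers): with cache_memory ≈ 60 GB,
    # gb_per_token ≈ 1.22e-4 GB, target_seq_len = 2560, and no extra cache,
    # kv_cache_max_requests ≈ 60 / (1.22e-4 * 2560) ≈ 192 concurrent requests.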

    logger.info(
        f"Estimated total cache memory: {cache_memory:.2f} GB. "
        f"KV cache: {kv_cache_memory:.2f} GB, "
        f"Extra cache: {extra_cache_memory:.2f} GB")
    logger.info(f"Estimated kv cache max tokens: {kv_cache_max_tokens:.2f}")
    logger.info("Estimated max number of requests in KV cache memory: "
                f"{kv_cache_max_requests:.2f}")

    # Fine-tune the max batch size and max num tokens settings for performance.
    # For mamba-attention hybrid models, optimistic tuning is disabled because
    # the mamba cache leaves less memory for the KV cache.
    max_batch_size, max_num_tokens = finetune_setting(
        kv_cache_max_requests,
        target_input_len,
        target_output_len,
        pp_size,
        disable_optimistic_tuning=isinstance(model_config,
                                             NemotronHybridConfig))

    # Functional and performance checks.
    if total_gpu_memory < engine_size:
        raise RuntimeError(
            f"The model requires at least {engine_size:.2f} GB, but the "
            f"total GPU memory of {total_gpu_memory:.2f} GB is insufficient.\n"
            "----------------------------------------------------------\n"
            "Estimation based on the following:\n"
            "----------------------------------------------------------\n"
            f"Bytes per Element: {byte_per_elem}\n"
            f"Bytes per KV Element: {byte_per_kv_elem}\n"
            f"Number of GPUs: {n_gpus}\n"
            f"Model Number of KV Heads: {model_config.num_key_value_heads}\n"
            f"Adjusted Number of KV Heads: {adjusted_num_kv_heads}\n"
            f"Head Size: {model_config.head_size}\n"
            f"Number of Hidden Layers: {model_config.num_hidden_layers}\n"
            f"Number of Tensor Parallel Shards: {tp_size}\n"
            f"Number of Pipeline Parallel Stages: {pp_size}\n"
            f"KV Cache GPU Memory Fraction: {kv_cache_gpu_mem_fraction}\n"
            "----------------------------------------------------------\n")
    if kv_cache_max_requests < 1:
        raise RuntimeError("The amount of KV cache memory is insufficient to "
                           "run this model. Please try with more GPUs.")
    if cache_memory / n_gpus < 10.0:
        logger.warning(
            "The KV cache memory per GPU is less than 10 GB. "
            "Performance may be undesirable. Please consider using a different "
            "mapping or more GPUs.")
    if kv_cache_max_requests < 32:
        logger.warning(
            "The maximum number of requests in the KV cache is too "
            "small. Performance may be undesirable. Please consider using more "
            "GPUs or a different mapping to process more concurrent requests.")

    return max_batch_size, max_num_tokens
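
# Minimal usage sketch (hypothetical values; assumes `model_cfg` is a
# ModelConfig already populated by trtllm-bench's model loading, and that
# QuantConfig is constructed via its quant_algo/kv_cache_quant_algo fields):
#   quant_cfg = QuantConfig(quant_algo=QuantAlgo.FP8,
#                           kv_cache_quant_algo=QuantAlgo.FP8)
#   max_bs, max_tokens = calc_engine_setting(model_cfg, quant_cfg,
#                                            tp_size=1, pp_size=1,
#                                            target_input_len=2048,
#                                            target_output_len=512)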


def finetune_setting(
    kv_cache_max_requests: float,
    input_len: int,
    output_len: int,
    pp_size: int,
    disable_optimistic_tuning: bool = False,
) -> Tuple[int, int]:
    """ Calculate and fine-tune the engine build settings (max batch size and
    max num tokens). Both max batch size and max num tokens are fine-tuned
    to be slightly optimistic.

    Args:
        kv_cache_max_requests (float): Max number of requests that can fit in
            the available KV cache memory.
        input_len (int): Input sequence length to compile the engine for.
        output_len (int): Output sequence length to compile the engine for.
        pp_size (int): Number of pipeline parallel stages.
        disable_optimistic_tuning (bool): Whether to disable optimistic tuning.

    Returns:
        Tuple[int, int]: Tuple containing the fine-tuned engine configuration
            as (max batch size, max num tokens).
    """
    # Cap total batch size to avoid decoder buffer size becoming too large.
    raw_bs = min(kv_cache_max_requests, 4096) / pp_size
    # Cap num tokens to avoid TRT activation buffer becoming too large.
    raw_token = min(raw_bs * (1 + input_len / output_len), 32768)

    # Fine-tune the max batch size.
    if disable_optimistic_tuning:
        max_bs = 2 * math.floor(raw_bs / 2)
    else:
        # Set min BS to be 64.
        if raw_bs < 256:
            max_bs = max(64, 32 * math.ceil(raw_bs / 32))
        elif raw_bs < 1024:
            max_bs = 128 * math.ceil(raw_bs / 128)
        else:
            max_bs = 256 * math.ceil(raw_bs / 256)

    # Fine-tune the max num tokens.
    # Set min to 2048 to ensure Ctx/Gen overlap efficiency.
    if raw_token < 4096:
        max_token = max(2048, 256 * math.ceil(raw_token / 256))
    elif raw_token < 8192:
        max_token = 512 * math.ceil(raw_token / 512)
    else:
        max_token = 1024 * math.ceil(raw_token / 1024)
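
    # Illustrative example (assumed inputs): kv_cache_max_requests = 300,
    # input_len = 2048, output_len = 512, pp_size = 1 give raw_bs = 300 and
    # raw_token = 300 * (1 + 2048 / 512) = 1500, which round up to
    # max_bs = 128 * ceil(300 / 128) = 384 and max_token = 2048 (the minimum).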

    logger.debug("Estimated max batch size (before fine-tune): "
                 f"{kv_cache_max_requests / pp_size:.2f}")
    logger.debug(
        f"Estimated max num tokens (before fine-tune): "
        f"{kv_cache_max_requests / pp_size * (1 + input_len / output_len) :.2f}"
    )
    logger.info(f"Estimated max batch size (after fine-tune): {max_bs}")
    logger.info(f"Estimated max num tokens (after fine-tune): {max_token}")

    return max_bs, max_token