[TRTLLM-5838][fix] fix max batch size and max tokens in kv cache estimations for Nemotron-H (#5371)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
tomeras91 2025-07-09 11:30:15 +03:00 committed by GitHub
parent 10e686466e
commit 5aa958a11a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 145 additions and 37 deletions

View File

@@ -7,6 +7,7 @@ import torch
 import transformers
 from tensorrt_llm import logger
+from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
 from tensorrt_llm._utils import torch_dtype_to_binding
 from tensorrt_llm.bindings import LayerType as LayerTypeCpp
 from tensorrt_llm.functional import AllReduceStrategy
@@ -298,7 +299,7 @@ class ModelConfig(Generic[TConfig]):
         model_config_cpp = ModelConfigCpp(
             vocab_size=self.pretrained_config.vocab_size,
             num_layers=self.pretrained_config.num_hidden_layers,
-            num_attention_layers=self.pretrained_config.num_hidden_layers,
+            num_attention_layers=self.get_num_attention_layers(),
             num_rnn_layers=0,
             num_heads=num_heads,
             hidden_size=hidden_size,
@@ -376,3 +377,9 @@ class ModelConfig(Generic[TConfig]):
             ] * self.pretrained_config.num_hidden_layers
         else:
             return None
+
+    def get_num_attention_layers(self):
+        if is_nemotron_hybrid(self.pretrained_config):
+            return self.pretrained_config.hybrid_override_pattern.count("*")
+        else:
+            return self.pretrained_config.num_hidden_layers
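For intuition, a small standalone sketch of what get_num_attention_layers counts (the pattern string below is made up; the only convention taken from this change is that "*" marks attention layers and "M" marks Mamba layers):

# Hypothetical 8-layer hybrid pattern; characters other than "M" and "*"
# (e.g. MLP-only layers) are simply not counted by either query.
pattern = "M-M*M-M*"
num_attention_layers = pattern.count("*")   # 2 -> only these layers hold KV cache
num_mamba_layers = pattern.count("M")       # 4 -> these layers hold conv/SSM state

For a pure-attention model the hybrid branch is skipped and the count stays at num_hidden_layers, so non-hybrid estimates are unchanged.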

View File

@@ -82,9 +82,9 @@ class KvCacheCreator:
         ) * num_key_value_heads // tp_size
         # provide at least 1 layer to prevent division by zero cache size
-        num_hidden_layers = max(
-            len(mapping.pp_layers(config.num_hidden_layers)), 1)
-        mem_per_token *= num_hidden_layers * head_dim
+        num_attention_layers = max(
+            len(mapping.pp_layers(model_config.get_num_attention_layers())), 1)
+        mem_per_token *= num_attention_layers * head_dim
         # K and V
         mem_per_token *= kv_factor
         return mem_per_token
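The numerical effect of switching the estimate from num_hidden_layers to attention-only layers is easy to see with a rough sketch (all model numbers below are invented for illustration, and the arithmetic simply mirrors the code above):

# Per-token KV-cache bytes, following the structure of the snippet above.
dtype_bytes = 2              # fp16/bf16 KV cache
num_key_value_heads = 8
tp_size = 1
head_dim = 128
kv_factor = 2                # K and V

num_hidden_layers = 52       # hypothetical hybrid model depth
num_attention_layers = 6     # e.g. hybrid_override_pattern.count("*")

mem_per_token = dtype_bytes * num_key_value_heads // tp_size
old_estimate = mem_per_token * num_hidden_layers * head_dim * kv_factor
new_estimate = mem_per_token * num_attention_layers * head_dim * kv_factor
print(old_estimate, new_estimate)   # 212992 vs 24576 bytes per token

With only a handful of attention layers, the old formula overestimated per-token KV memory by almost an order of magnitude, which in turn shrank the estimated max batch size and max num tokens.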

View File

@@ -818,7 +818,7 @@ class MambaCacheManager(BaseResourceManager):
                                           device=device,
                                           dtype=torch.int32)

-    def prepare_mamba_cache_blocks(self, request_ids: List[int]):
+    def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
         state_indices = []
         for r in request_ids:
             # cache hit
@@ -834,12 +834,7 @@ class MambaCacheManager(BaseResourceManager):
         self.state_indices[:len(state_indices)] = torch.as_tensor(
             state_indices, dtype=torch.int32, device=self.ssm_states.device)

-    def free_mamba_cache_blocks(self, request_id: int):
-        if request_id in self.mamba_cache_index:
-            block = self.mamba_cache_index.pop(request_id)
-            self.mamba_cache_free_blocks.append(block)
-
-    def prepare_mamba_resources(self, scheduled_batch: ScheduledRequests):
+    def prepare_resources(self, scheduled_batch: ScheduledRequests):
         context_ids = [
             i.py_request_id for i in scheduled_batch.context_requests
         ]
@@ -847,10 +842,13 @@ class MambaCacheManager(BaseResourceManager):
             i.py_request_id for i in scheduled_batch.generation_requests
         ]
         request_ids = context_ids + generation_ids
-        self.prepare_mamba_cache_blocks(request_ids)
+        self._prepare_mamba_cache_blocks(request_ids)

-    def free_mamba_resources(self, request: LlmRequest):
-        self.free_mamba_cache_blocks(request.py_request_id)
+    def free_resources(self, request: LlmRequest):
+        request_id = request.py_request_id
+        if request_id in self.mamba_cache_index:
+            block = self.mamba_cache_index.pop(request_id)
+            self.mamba_cache_free_blocks.append(block)

     def get_state_indices(self) -> torch.Tensor:
         return self.state_indices
@@ -863,6 +861,13 @@ class MambaCacheManager(BaseResourceManager):
         layer_offset = self.mamba_layer_offsets[layer_idx]
         return self.ssm_states[layer_offset]

+    def shutdown(self):
+        # release tensor memory, keeping python references as tensors
+        self.conv_states = torch.tensor([])
+        self.ssm_states = torch.tensor([])
+        self.state_indices = torch.tensor([])
+        torch.cuda.empty_cache()
+

 class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
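The shutdown pattern above relies on ordinary PyTorch behavior rather than any special API; a tiny self-contained illustration (requires a CUDA device, sizes arbitrary):

import torch

# Rebinding the only reference to a CUDA tensor releases its block back to the
# caching allocator; empty_cache() then returns cached blocks to the driver.
states = torch.zeros(1 << 20, device="cuda")   # ~4 MB of fp32 "cache state"
print(torch.cuda.memory_allocated())           # > 0
states = torch.tensor([])                      # keep the attribute, drop the storage
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated())           # back to (near) zero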
@@ -933,12 +938,16 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
         )

     def prepare_resources(self, scheduled_batch: ScheduledRequests):
-        self.prepare_mamba_resources(scheduled_batch)
-        super().prepare_resources(scheduled_batch)
+        MambaCacheManager.prepare_resources(self, scheduled_batch)
+        KVCacheManager.prepare_resources(self, scheduled_batch)

     def free_resources(self, request: LlmRequest):
-        self.free_mamba_resources(request)
-        super().free_resources(request)
+        MambaCacheManager.free_resources(self, request)
+        KVCacheManager.free_resources(self, request)

+    def shutdown(self):
+        MambaCacheManager.shutdown(self)
+        KVCacheManager.shutdown(self)
+

 class SlotManager:
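The switch from super() to explicit MambaCacheManager./KVCacheManager. calls matters because of Python's MRO: with multiple inheritance, a single super() call only reaches the next class in the MRO, and both resource pools would run only if every base cooperated. A minimal sketch with toy classes (not the real managers):

class KvSide:
    def prepare(self):
        print("kv prepare")

class MambaSide:
    def prepare(self):
        print("mamba prepare")

class HybridViaSuper(KvSide, MambaSide):
    def prepare(self):
        super().prepare()           # follows the MRO: only KvSide.prepare runs

class HybridExplicit(KvSide, MambaSide):
    def prepare(self):
        MambaSide.prepare(self)     # both sides run, in a chosen order
        KvSide.prepare(self)

HybridViaSuper().prepare()    # kv prepare
HybridExplicit().prepare()    # mamba prepare, then kv prepare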

View File

@@ -130,6 +130,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
         params.get("pp"),
         dataset_metadata.avg_isl,
         dataset_metadata.avg_osl,
+        params.get("kv_cache_free_gpu_mem_fraction"),
     )
     logger.info(

View File

@@ -1,10 +1,12 @@
 from __future__ import annotations

+from transformers import AutoConfig
 from pathlib import Path
 from typing import Tuple, get_args
 import click
 from click_option_group import AllOptionGroup, optgroup
+from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
 from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
 from tensorrt_llm.bench.utils.data import create_dataset_from_stream, initialize_tokenizer
 from tensorrt_llm.bench.utils import VALID_QUANT_ALGOS
@@ -13,7 +15,7 @@ from tensorrt_llm._tensorrt_engine import LLM
 from tensorrt_llm.llmapi.llm_utils import QuantConfig
 from tensorrt_llm.logger import logger
 from tensorrt_llm.quantization.mode import QuantAlgo
-from tensorrt_llm.bench.build.dataclasses import ModelConfig
+from tensorrt_llm.bench.build.dataclasses import ModelConfig, NemotronHybridConfig
 from tensorrt_llm.bench.build.tuning import calc_engine_setting

 TUNED_QUANTS = {
@@ -31,6 +33,7 @@ def get_benchmark_engine_settings(
     pp_size: int,
     target_input_len: int,
     target_output_len: int,
+    kv_cache_gpu_mem_fraction: float = 0.95,
 ) -> Tuple[int, int]:
     """ Retrieve benchmark settings for a specific model + configuration.
@@ -58,6 +61,7 @@ def get_benchmark_engine_settings(
             pp_size,
             target_input_len,
             target_output_len,
+            kv_cache_gpu_mem_fraction,
         )
     else:
         max_batch_size = DEFAULT_MAX_BATCH_SIZE
@@ -82,6 +86,10 @@ def get_model_config(model_name: str, model_path: Path = None) -> ModelConfig:
     Raises:
         ValueError: When model is not supported.
     """
+    if is_nemotron_hybrid(
+            AutoConfig.from_pretrained(model_path or model_name,
+                                       trust_remote_code=True)):
+        return NemotronHybridConfig.from_hf(model_name, model_path)
     return ModelConfig.from_hf(model_name, model_path)

View File

@@ -124,6 +124,7 @@ class ModelConfig(BaseModel):
         AliasPath("text_config", "num_hidden_layers"),
         AliasPath("language_config", "num_hidden_layers"),
     ))
+    num_attention_layers: Optional[int] = Field(default=None)
     num_attention_heads: int = Field(validation_alias=AliasChoices(
         "num_attention_heads",
         "n_head",
@@ -148,6 +149,7 @@ class ModelConfig(BaseModel):
         validation_alias=AliasChoices(
             "head_size",
             "head_dim",
+            "attention_head_dim",
             AliasPath("text_config", "head_dim"),
         ))
     max_position_embeddings: Optional[int] = Field(
@@ -171,6 +173,8 @@ class ModelConfig(BaseModel):
             self.num_key_value_heads = self.num_attention_heads
         if self.head_size is None:
             self.head_size = self.hidden_size // self.num_attention_heads
+        if self.num_attention_layers is None:
+            self.num_attention_layers = self.num_hidden_layers
         return self

     @classmethod
@@ -194,3 +198,59 @@ class ModelConfig(BaseModel):
         param_count = cls.get_param_count(model_hf_name, hf_model_path)
         return cls(name=model_hf_name, param_count=param_count, **hf_config)
+
+    def extra_model_cache_in_gb(self, bytes_per_elem, target_seq_len=None):
+        return 0
+
+    def cache_memory_fraction(self, cache_memory_fraction):
+        return cache_memory_fraction
+
+
+class NemotronHybridConfig(ModelConfig):
+    hybrid_override_pattern: str
+    d_state: int = Field(validation_alias=AliasChoices(
+        "d_state",
+        "mamba_d_state",
+        "ssm_state_size",
+    ))
+    d_conv: int = Field(validation_alias=AliasChoices(
+        "d_conv",
+        "mamba_d_conv",
+        "conv_kernel",
+    ))
+    expand: int = Field(validation_alias=AliasChoices(
+        "expand",
+        "mamba_expand",
+    ))
+    n_groups: int
+    mamba_head_dim: int
+    d_inner: Optional[int] = Field(default=None)
+    mamba_num_heads: Optional[int] = Field(default=None)
+    num_mamba_layers: Optional[int] = Field(default=None)
+
+    @model_validator(mode="after")
+    def set_values_if_none(self):
+        """ Set the values if cannot get values from HF config.json. """
+        if not self.d_inner:
+            self.d_inner = self.hidden_size * self.expand
+        if not self.mamba_num_heads:
+            self.mamba_num_heads = self.d_inner // self.mamba_head_dim
+        if self.num_mamba_layers is None:
+            self.num_mamba_layers = self.hybrid_override_pattern.count("M")
+        if self.num_attention_layers is None:
+            self.num_attention_layers = self.hybrid_override_pattern.count("*")
+
+        super().set_values_if_none()
+        return self
+
+    def extra_model_cache_in_gb(self, bytes_per_elem, target_seq_len=None):
+        conv_dim = self.d_inner + 2 * self.n_groups * self.d_state
+        conv_state_elems = conv_dim * (self.d_conv - 1)
+        ssm_state_elems = self.mamba_num_heads * self.mamba_head_dim * self.d_state
+        gb_per_mamba_cache = bytes_per_elem * self.num_mamba_layers * (
+            conv_state_elems + ssm_state_elems) / (1024**3)
+        return gb_per_mamba_cache
+
+    def cache_memory_fraction(self, cache_memory_fraction):
+        # Each mamba cache entry is pretty large (~50MB for 8B model), so we are more conservative when estimating the max batch size
+        return cache_memory_fraction**2
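To see where the "~50MB" figure and the squared fraction come from, it helps to plug rough numbers into extra_model_cache_in_gb. The values below are illustrative stand-ins for an ~8B Mamba-attention hybrid, not read from any config.json:

bytes_per_elem = 2                            # bf16 states
hidden_size, expand = 4096, 2
d_inner = hidden_size * expand                # 8192
n_groups, d_state, d_conv = 8, 128, 4
mamba_head_dim = 64
mamba_num_heads = d_inner // mamba_head_dim   # 128
num_mamba_layers = 27                         # e.g. hybrid_override_pattern.count("M")

conv_dim = d_inner + 2 * n_groups * d_state                     # 10240
conv_state_elems = conv_dim * (d_conv - 1)                      # 30720
ssm_state_elems = mamba_num_heads * mamba_head_dim * d_state    # 1048576
gb_per_request = bytes_per_elem * num_mamba_layers * (
    conv_state_elems + ssm_state_elems) / (1024**3)
print(f"{gb_per_request * 1024:.1f} MiB per request")           # ~55.6 MiB

Each admitted request holds one such slot regardless of sequence length, so squaring the memory fraction keeps the max-batch-size estimate conservative.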

View File

@@ -3,7 +3,7 @@ from typing import Tuple
 from tensorrt_llm.llmapi.llm_utils import QuantConfig
 from tensorrt_llm.logger import logger
 from tensorrt_llm.quantization.mode import QuantAlgo
-from tensorrt_llm.bench.build.dataclasses import ModelConfig
+from tensorrt_llm.bench.build.dataclasses import ModelConfig, NemotronHybridConfig
 from .utils import get_device_memory
 import math
@@ -55,7 +55,11 @@ def calc_engine_setting(
     # Each GPU in TP group has at least 1 kv head
     adjusted_num_kv_heads = max(tp_size, model_config.num_key_value_heads)
-    byte_per_token = 2 * model_config.num_hidden_layers * adjusted_num_kv_heads \
+    logger.info(
+        f"Number of attention layers: {model_config.num_attention_layers}")
+    gb_per_token = 2 * model_config.num_attention_layers * adjusted_num_kv_heads \
         * model_config.head_size * byte_per_kv_elem / (1024 ** 3)

     # Number of GPU used for this run.
@@ -70,19 +74,33 @@ def calc_engine_setting(
                 f"{available_memory:.2f} GB")

     # Calculate max requests in KV cache based on target ISL and OSL.
-    kv_cache_memory = available_memory * kv_cache_gpu_mem_fraction
-    kv_cache_max_tokens = kv_cache_memory / byte_per_token
-    kv_cache_max_requests = kv_cache_max_tokens / (target_input_len +
-                                                   target_output_len)
-    logger.info(f"Estimated total KV cache memory: {kv_cache_memory:.2f} GB")
+    target_seq_len = target_input_len + target_output_len
+    cache_memory = available_memory * model_config.cache_memory_fraction(
+        kv_cache_gpu_mem_fraction)
+    gb_per_extra_cache = model_config.extra_model_cache_in_gb(
+        BYTES_PER_ELEM.get(QuantAlgo.NO_QUANT), target_seq_len)
+    kv_cache_max_requests = cache_memory / (gb_per_token * target_seq_len +
+                                            gb_per_extra_cache)
+    extra_cache_memory = gb_per_extra_cache * kv_cache_max_requests
+    kv_cache_memory = cache_memory - extra_cache_memory
+    kv_cache_max_tokens = kv_cache_memory / gb_per_token
+    logger.info(
+        f"Estimated total cache memory: {cache_memory:.2f} GB. KV cache: {kv_cache_memory:.2f} GB, Extra cache: {extra_cache_memory:.2f} GB"
+    )
     logger.info(f"Estimated kv cache max tokens: {kv_cache_max_tokens:.2f}")
     logger.info("Estimated max number of requests in KV cache memory: "
                 f"{kv_cache_max_requests:.2f}")

     # Fine-tune the max batch size and num token setting for performance.
-    max_batch_size, max_num_tokens = finetune_setting(kv_cache_max_requests,
-                                                      target_input_len,
-                                                      target_output_len,
-                                                      pp_size)
+    # For mamba-attn hybrid models, we disable optimistic tuning because the mamba cache leaves less memory for the KV cache
+    max_batch_size, max_num_tokens = finetune_setting(
+        kv_cache_max_requests,
+        target_input_len,
+        target_output_len,
+        pp_size,
+        disable_optimistic_tuning=isinstance(model_config,
+                                             NemotronHybridConfig))

     # Functional and performance
     if total_gpu_memory < engine_size:
@@ -107,7 +125,7 @@ def calc_engine_setting(
     if kv_cache_max_requests < 1:
         raise RuntimeError("The amount of KV cache memory is insufficient to "
                            "run this model. Please try with more GPUs.")
-    if kv_cache_memory / n_gpus < 10.0:
+    if cache_memory / n_gpus < 10.0:
         logger.warning(
             f"The KV cache memory per GPU is less than 10 GB. "
             "Performance may be undesirable. Please consider using a different "
@@ -126,6 +144,7 @@ def finetune_setting(
     input_len: int,
     output_len: int,
     pp_size: int,
+    disable_optimistic_tuning: bool = False,
 ) -> Tuple[int, int]:
     """ Calculate and fine-tune the engine build settings (max batch size and
         max num tokens). Both max batch size and max num tokens are fine-tuned
@@ -137,6 +156,7 @@ def finetune_setting(
         input_len (int): Input sequence length to compile the engine.
         output_len (int): Output sequence length to compile the engine.
         pp_size (int): Number of pipeline parallel stages.
+        disable_optimistic_tuning (bool): Whether to disable optimistic tuning.

     Returns:
         Tuple[int, int]: Tuple containing fine-tuned values for engine
@@ -148,13 +168,16 @@ def finetune_setting(
     raw_token = min(raw_bs * (1 + input_len / output_len), 32768)

     # Fine-tune the max batch size.
-    # Set min BS to be 64.
-    if raw_bs < 256:
-        max_bs = max(64, 32 * math.ceil(raw_bs / 32))
-    elif raw_bs < 1024:
-        max_bs = 128 * math.ceil(raw_bs / 128)
+    if disable_optimistic_tuning:
+        max_bs = 2 * math.floor(raw_bs / 2)
     else:
-        max_bs = 256 * math.ceil(raw_bs / 256)
+        # Set min BS to be 64.
+        if raw_bs < 256:
+            max_bs = max(64, 32 * math.ceil(raw_bs / 32))
+        elif raw_bs < 1024:
+            max_bs = 128 * math.ceil(raw_bs / 128)
+        else:
+            max_bs = 256 * math.ceil(raw_bs / 256)

     # Fine-tune the max num tokens.
     # Set min to 2048 to ensure Ctx/Gen overlap efficiency
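Putting the new estimation path together, a rough end-to-end sketch of the batch-size math (all inputs invented; it mirrors the formulas above instead of calling the real functions, and the pipeline-parallel and max-num-tokens adjustments are omitted):

import math

available_memory = 60.0        # GB left after weights and activations
kv_fraction = 0.9              # kv_cache_gpu_mem_fraction
gb_per_token = 2.3e-5          # attention-layers-only KV cost from above
gb_per_extra_cache = 0.054     # per-request Mamba cache from above
target_seq_len = 1024 + 1024   # ISL + OSL

# Hybrid model: the memory fraction is squared and the per-request extra
# cache is added to the per-request KV footprint.
cache_memory = available_memory * kv_fraction**2
raw_bs = cache_memory / (gb_per_token * target_seq_len + gb_per_extra_cache)

# With disable_optimistic_tuning, round down to an even batch size instead of
# rounding up to the next multiple of 32/128/256.
max_batch_size = 2 * math.floor(raw_bs / 2)
print(round(raw_bs, 1), max_batch_size)   # 480.7 480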