Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[TRTLLM-5838][fix] fix max batch size and max tokens in kv cache estimations for Nemotron-H (#5371)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
commit 5aa958a11a (parent 10e686466e)
@@ -7,6 +7,7 @@ import torch
 import transformers

 from tensorrt_llm import logger
+from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
 from tensorrt_llm._utils import torch_dtype_to_binding
 from tensorrt_llm.bindings import LayerType as LayerTypeCpp
 from tensorrt_llm.functional import AllReduceStrategy
@@ -298,7 +299,7 @@ class ModelConfig(Generic[TConfig]):
         model_config_cpp = ModelConfigCpp(
             vocab_size=self.pretrained_config.vocab_size,
             num_layers=self.pretrained_config.num_hidden_layers,
-            num_attention_layers=self.pretrained_config.num_hidden_layers,
+            num_attention_layers=self.get_num_attention_layers(),
             num_rnn_layers=0,
             num_heads=num_heads,
             hidden_size=hidden_size,
@@ -376,3 +377,9 @@ class ModelConfig(Generic[TConfig]):
             ] * self.pretrained_config.num_hidden_layers
         else:
             return None
+
+    def get_num_attention_layers(self):
+        if is_nemotron_hybrid(self.pretrained_config):
+            return self.pretrained_config.hybrid_override_pattern.count("*")
+        else:
+            return self.pretrained_config.num_hidden_layers
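For a hybrid Nemotron-H checkpoint, hybrid_override_pattern encodes the block type of every layer as a string, so counting "*" characters yields the number of attention layers. A minimal sketch of that counting, using a made-up pattern (the real alphabet and layout come from the model's config.json):

# Hypothetical pattern: "M" = Mamba block, "*" = attention block, "-" = MLP block.
hybrid_override_pattern = "M-M-M*M-M-M*M-M-M*"
num_attention_layers = hybrid_override_pattern.count("*")  # 3 in this example
num_mamba_layers = hybrid_override_pattern.count("M")      # 9 in this example
print(num_attention_layers, num_mamba_layers)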
@@ -82,9 +82,9 @@ class KvCacheCreator:
         ) * num_key_value_heads // tp_size

         # provide at least 1 layer to prevent division by zero cache size
-        num_hidden_layers = max(
-            len(mapping.pp_layers(config.num_hidden_layers)), 1)
-        mem_per_token *= num_hidden_layers * head_dim
+        num_attention_layers = max(
+            len(mapping.pp_layers(model_config.get_num_attention_layers())), 1)
+        mem_per_token *= num_attention_layers * head_dim
         # K and V
         mem_per_token *= kv_factor
         return mem_per_token
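The change matters for Nemotron-H because only the attention layers allocate KV cache; the Mamba layers keep their state in the separate Mamba cache. A rough per-token estimate in the spirit of this hunk (the function name, defaults, and dtype size below are assumptions, not the exact KvCacheCreator code):

def estimate_kv_bytes_per_token(num_attention_layers: int,
                                num_key_value_heads: int,
                                head_dim: int,
                                tp_size: int = 1,
                                bytes_per_elem: int = 2,     # fp16/bf16 KV cache
                                kv_factor: int = 2) -> int:  # K and V
    """Rough KV-cache bytes per token per rank (illustrative sketch only)."""
    heads_per_rank = max(num_key_value_heads // tp_size, 1)
    return (bytes_per_elem * heads_per_rank * head_dim *
            num_attention_layers * kv_factor)

# Counting every hidden layer instead of only the attention layers would
# overestimate this figure by roughly num_hidden_layers / num_attention_layers.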
@@ -818,7 +818,7 @@ class MambaCacheManager(BaseResourceManager):
             device=device,
             dtype=torch.int32)

-    def prepare_mamba_cache_blocks(self, request_ids: List[int]):
+    def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
         state_indices = []
         for r in request_ids:
             # cache hit
@@ -834,12 +834,7 @@ class MambaCacheManager(BaseResourceManager):
         self.state_indices[:len(state_indices)] = torch.as_tensor(
             state_indices, dtype=torch.int32, device=self.ssm_states.device)

-    def free_mamba_cache_blocks(self, request_id: int):
-        if request_id in self.mamba_cache_index:
-            block = self.mamba_cache_index.pop(request_id)
-            self.mamba_cache_free_blocks.append(block)
-
-    def prepare_mamba_resources(self, scheduled_batch: ScheduledRequests):
+    def prepare_resources(self, scheduled_batch: ScheduledRequests):
         context_ids = [
             i.py_request_id for i in scheduled_batch.context_requests
         ]
@@ -847,10 +842,13 @@ class MambaCacheManager(BaseResourceManager):
             i.py_request_id for i in scheduled_batch.generation_requests
         ]
         request_ids = context_ids + generation_ids
-        self.prepare_mamba_cache_blocks(request_ids)
+        self._prepare_mamba_cache_blocks(request_ids)

-    def free_mamba_resources(self, request: LlmRequest):
-        self.free_mamba_cache_blocks(request.py_request_id)
+    def free_resources(self, request: LlmRequest):
+        request_id = request.py_request_id
+        if request_id in self.mamba_cache_index:
+            block = self.mamba_cache_index.pop(request_id)
+            self.mamba_cache_free_blocks.append(block)

     def get_state_indices(self) -> torch.Tensor:
         return self.state_indices
@@ -863,6 +861,13 @@ class MambaCacheManager(BaseResourceManager):
         layer_offset = self.mamba_layer_offsets[layer_idx]
         return self.ssm_states[layer_offset]

+    def shutdown(self):
+        # release tensor memory, keeping python references as tensors
+        self.conv_states = torch.tensor([])
+        self.ssm_states = torch.tensor([])
+        self.state_indices = torch.tensor([])
+        torch.cuda.empty_cache()
+

 class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):

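In the shutdown method added above, rebinding the state attributes to empty tensors drops the last references to the large CUDA allocations, and torch.cuda.empty_cache() then returns the now-unused blocks from PyTorch's caching allocator to the driver. A standalone sketch of the same idea (requires a CUDA device; the tensor shape is arbitrary):

import torch

ssm_states = torch.zeros(1024, 1024, 64, device="cuda")        # ~256 MiB placeholder state
print(torch.cuda.memory_allocated() // 2**20, "MiB allocated")

ssm_states = torch.tensor([])   # drop the only reference, keep the attribute a tensor
torch.cuda.empty_cache()        # hand cached but unused blocks back to the driver
print(torch.cuda.memory_allocated() // 2**20, "MiB allocated")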
@@ -933,12 +938,16 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
         )

     def prepare_resources(self, scheduled_batch: ScheduledRequests):
-        self.prepare_mamba_resources(scheduled_batch)
-        super().prepare_resources(scheduled_batch)
+        MambaCacheManager.prepare_resources(self, scheduled_batch)
+        KVCacheManager.prepare_resources(self, scheduled_batch)

     def free_resources(self, request: LlmRequest):
-        self.free_mamba_resources(request)
-        super().free_resources(request)
+        MambaCacheManager.free_resources(self, request)
+        KVCacheManager.free_resources(self, request)

+    def shutdown(self):
+        MambaCacheManager.shutdown(self)
+        KVCacheManager.shutdown(self)
+

 class SlotManager:

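Because MambaCacheManager now defines prepare_resources, free_resources, and shutdown under the same names as KVCacheManager, the hybrid class calls each base explicitly: a bare super() call would stop at the first class in the MRO and silently skip the other base. A toy illustration of that dispatch choice (not the real managers):

class KVBase:
    def prepare(self):
        print("KV cache prepared")

class MambaBase:
    def prepare(self):
        print("Mamba cache prepared")

class Hybrid(KVBase, MambaBase):
    def prepare(self):
        # super().prepare() would only reach KVBase.prepare(), so name both bases.
        MambaBase.prepare(self)
        KVBase.prepare(self)

Hybrid().prepare()  # prints "Mamba cache prepared", then "KV cache prepared"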
@@ -130,6 +130,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
         params.get("pp"),
         dataset_metadata.avg_isl,
         dataset_metadata.avg_osl,
+        params.get("kv_cache_free_gpu_mem_fraction"),
     )

     logger.info(
@@ -1,10 +1,12 @@
 from __future__ import annotations
+from transformers import AutoConfig

 from pathlib import Path
 from typing import Tuple, get_args
 import click
 from click_option_group import AllOptionGroup, optgroup

+from tensorrt_llm._torch.pyexecutor.config_utils import is_nemotron_hybrid
 from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
 from tensorrt_llm.bench.utils.data import create_dataset_from_stream, initialize_tokenizer
 from tensorrt_llm.bench.utils import VALID_QUANT_ALGOS
@@ -13,7 +15,7 @@ from tensorrt_llm._tensorrt_engine import LLM
 from tensorrt_llm.llmapi.llm_utils import QuantConfig
 from tensorrt_llm.logger import logger
 from tensorrt_llm.quantization.mode import QuantAlgo
-from tensorrt_llm.bench.build.dataclasses import ModelConfig
+from tensorrt_llm.bench.build.dataclasses import ModelConfig, NemotronHybridConfig
 from tensorrt_llm.bench.build.tuning import calc_engine_setting

 TUNED_QUANTS = {
@@ -31,6 +33,7 @@ def get_benchmark_engine_settings(
     pp_size: int,
     target_input_len: int,
     target_output_len: int,
+    kv_cache_gpu_mem_fraction: float = 0.95,
 ) -> Tuple[int, int]:
     """ Retrieve benchmark settings for a specific model + configuration.

@@ -58,6 +61,7 @@ def get_benchmark_engine_settings(
             pp_size,
             target_input_len,
             target_output_len,
+            kv_cache_gpu_mem_fraction,
         )
     else:
         max_batch_size = DEFAULT_MAX_BATCH_SIZE
@@ -82,6 +86,10 @@ def get_model_config(model_name: str, model_path: Path = None) -> ModelConfig:
     Raises:
         ValueError: When model is not supported.
     """
+    if is_nemotron_hybrid(
+            AutoConfig.from_pretrained(model_path or model_name,
+                                       trust_remote_code=True)):
+        return NemotronHybridConfig.from_hf(model_name, model_path)
     return ModelConfig.from_hf(model_name, model_path)


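get_model_config now inspects the Hugging Face config before picking a dataclass, so hybrid checkpoints are routed to NemotronHybridConfig. The real predicate lives in tensorrt_llm._torch.pyexecutor.config_utils; the stand-in below only suggests the kind of attribute check it presumably performs and is not the actual implementation:

def is_nemotron_hybrid_like(hf_config) -> bool:
    # Assumption: hybrid Nemotron-H configs expose a hybrid_override_pattern
    # attribute, while plain transformer configs do not.
    return getattr(hf_config, "hybrid_override_pattern", None) is not None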
@@ -124,6 +124,7 @@ class ModelConfig(BaseModel):
         AliasPath("text_config", "num_hidden_layers"),
         AliasPath("language_config", "num_hidden_layers"),
     ))
+    num_attention_layers: Optional[int] = Field(default=None)
     num_attention_heads: int = Field(validation_alias=AliasChoices(
         "num_attention_heads",
         "n_head",
@@ -148,6 +149,7 @@ class ModelConfig(BaseModel):
         validation_alias=AliasChoices(
             "head_size",
             "head_dim",
+            "attention_head_dim",
             AliasPath("text_config", "head_dim"),
         ))
     max_position_embeddings: Optional[int] = Field(
@@ -171,6 +173,8 @@ class ModelConfig(BaseModel):
             self.num_key_value_heads = self.num_attention_heads
         if self.head_size is None:
             self.head_size = self.hidden_size // self.num_attention_heads
+        if self.num_attention_layers is None:
+            self.num_attention_layers = self.num_hidden_layers
         return self

     @classmethod
@@ -194,3 +198,59 @@ class ModelConfig(BaseModel):
         param_count = cls.get_param_count(model_hf_name, hf_model_path)

         return cls(name=model_hf_name, param_count=param_count, **hf_config)
+
+    def extra_model_cache_in_gb(self, bytes_per_elem, target_seq_len=None):
+        return 0
+
+    def cache_memory_fraction(self, cache_memory_fraction):
+        return cache_memory_fraction
+
+
+class NemotronHybridConfig(ModelConfig):
+    hybrid_override_pattern: str
+    d_state: int = Field(validation_alias=AliasChoices(
+        "d_state",
+        "mamba_d_state",
+        "ssm_state_size",
+    ))
+    d_conv: int = Field(validation_alias=AliasChoices(
+        "d_conv",
+        "mamba_d_conv",
+        "conv_kernel",
+    ))
+    expand: int = Field(validation_alias=AliasChoices(
+        "expand",
+        "mamba_expand",
+    ))
+    n_groups: int
+    mamba_head_dim: int
+    d_inner: Optional[int] = Field(default=None)
+    mamba_num_heads: Optional[int] = Field(default=None)
+    num_mamba_layers: Optional[int] = Field(default=None)
+
+    @model_validator(mode="after")
+    def set_values_if_none(self):
+        """ Set the values if cannot get values from HF config.json. """
+        if not self.d_inner:
+            self.d_inner = self.hidden_size * self.expand
+        if not self.mamba_num_heads:
+            self.mamba_num_heads = self.d_inner // self.mamba_head_dim
+        if self.num_mamba_layers is None:
+            self.num_mamba_layers = self.hybrid_override_pattern.count("M")
+        if self.num_attention_layers is None:
+            self.num_attention_layers = self.hybrid_override_pattern.count("*")
+
+        super().set_values_if_none()
+        return self
+
+    def extra_model_cache_in_gb(self, bytes_per_elem, target_seq_len=None):
+        conv_dim = self.d_inner + 2 * self.n_groups * self.d_state
+        conv_state_elems = conv_dim * (self.d_conv - 1)
+        ssm_state_elems = self.mamba_num_heads * self.mamba_head_dim * self.d_state
+        gb_per_mamba_cache = bytes_per_elem * self.num_mamba_layers * (
+            conv_state_elems + ssm_state_elems) / (1024**3)
+        return gb_per_mamba_cache
+
+    def cache_memory_fraction(self, cache_memory_fraction):
+        # Each mamba cache entry is pretty large (~50MB for 8B model), so we are more conservative when estimating the max batch size
+        return cache_memory_fraction**2
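To gauge the per-request Mamba cache that extra_model_cache_in_gb computes, here is a worked example with assumed Nemotron-H-style dimensions (illustrative numbers, not the published 8B config); it lands near the ~50 MB figure cited in the comment:

# Assumed: hidden_size=4096, expand=2, n_groups=8, d_state=128, d_conv=4,
# mamba_head_dim=64, 27 Mamba layers, 2-byte (bf16) states.
d_inner = 4096 * 2                              # 8192
conv_dim = d_inner + 2 * 8 * 128                # 10240
conv_state_elems = conv_dim * (4 - 1)           # 30720
ssm_state_elems = (d_inner // 64) * 64 * 128    # 1048576
bytes_per_request = 2 * 27 * (conv_state_elems + ssm_state_elems)
print(bytes_per_request / 1024**3)              # ~0.054 GB, i.e. ~54 MB per request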
@@ -3,7 +3,7 @@ from typing import Tuple
 from tensorrt_llm.llmapi.llm_utils import QuantConfig
 from tensorrt_llm.logger import logger
 from tensorrt_llm.quantization.mode import QuantAlgo
-from tensorrt_llm.bench.build.dataclasses import ModelConfig
+from tensorrt_llm.bench.build.dataclasses import ModelConfig, NemotronHybridConfig
 from .utils import get_device_memory
 import math

@@ -55,7 +55,11 @@ def calc_engine_setting(

     # Each GPU in TP group has at least 1 kv head
     adjusted_num_kv_heads = max(tp_size, model_config.num_key_value_heads)
-    byte_per_token = 2 * model_config.num_hidden_layers * adjusted_num_kv_heads \
+
+    logger.info(
+        f"Number of attention layers: {model_config.num_attention_layers}")
+
+    gb_per_token = 2 * model_config.num_attention_layers * adjusted_num_kv_heads \
         * model_config.head_size * byte_per_kv_elem / (1024 ** 3)

     # Number of GPU used for this run.
@@ -70,19 +74,33 @@ def calc_engine_setting(
         f"{available_memory:.2f} GB")

     # Calculate max requests in KV cache based on target ISL and OSL.
-    kv_cache_memory = available_memory * kv_cache_gpu_mem_fraction
-    kv_cache_max_tokens = kv_cache_memory / byte_per_token
-    kv_cache_max_requests = kv_cache_max_tokens / (target_input_len +
-                                                   target_output_len)
-    logger.info(f"Estimated total KV cache memory: {kv_cache_memory:.2f} GB")
+    target_seq_len = target_input_len + target_output_len
+    cache_memory = available_memory * model_config.cache_memory_fraction(
+        kv_cache_gpu_mem_fraction)
+    gb_per_extra_cache = model_config.extra_model_cache_in_gb(
+        BYTES_PER_ELEM.get(QuantAlgo.NO_QUANT), target_seq_len)
+    kv_cache_max_requests = cache_memory / (gb_per_token * target_seq_len +
+                                            gb_per_extra_cache)
+    extra_cache_memory = gb_per_extra_cache * kv_cache_max_requests
+    kv_cache_memory = cache_memory - extra_cache_memory
+    kv_cache_max_tokens = kv_cache_memory / gb_per_token
+
+    logger.info(
+        f"Estimated total cache memory: {cache_memory:.2f} GB. KV cache: {kv_cache_memory:.2f} GB, Extra cache: {extra_cache_memory:.2f} GB"
+    )
+    logger.info(f"Estimated kv cache max tokens: {kv_cache_max_tokens:.2f}")
     logger.info("Estimated max number of requests in KV cache memory: "
                 f"{kv_cache_max_requests:.2f}")

     # Fine-tune the max batch size and num token setting for performance.
-    max_batch_size, max_num_tokens = finetune_setting(kv_cache_max_requests,
-                                                      target_input_len,
-                                                      target_output_len,
-                                                      pp_size)
+    # For mamba-attn hybrid models, we disable optimistic tuning because the mamba cache leaves less memory for the KV cache
+    max_batch_size, max_num_tokens = finetune_setting(
+        kv_cache_max_requests,
+        target_input_len,
+        target_output_len,
+        pp_size,
+        disable_optimistic_tuning=isinstance(model_config,
+                                             NemotronHybridConfig))

     # Functional and performance
     if total_gpu_memory < engine_size:
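The new sizing solves one memory budget for both caches: each in-flight request costs roughly gb_per_token * target_seq_len of KV cache plus gb_per_extra_cache of Mamba state, so cache_memory = kv_cache_max_requests * (gb_per_token * target_seq_len + gb_per_extra_cache), which rearranges to the kv_cache_max_requests expression above. A small numeric check with made-up values (not from a real run):

cache_memory = 60.0          # GB available for caches on this rank (assumed)
gb_per_token = 1.0e-4        # KV cache GB per token over all attention layers (assumed)
gb_per_extra_cache = 0.054   # Mamba conv + ssm state GB per request (assumed)
target_seq_len = 4096        # target ISL + OSL (assumed)

kv_cache_max_requests = cache_memory / (gb_per_token * target_seq_len +
                                        gb_per_extra_cache)
extra_cache_memory = gb_per_extra_cache * kv_cache_max_requests
kv_cache_memory = cache_memory - extra_cache_memory
kv_cache_max_tokens = kv_cache_memory / gb_per_token
print(round(kv_cache_max_requests), round(kv_cache_max_tokens))
# ~129 requests; the KV cache keeps ~53 GB (~530k tokens), the Mamba cache ~7 GB.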
@@ -107,7 +125,7 @@ def calc_engine_setting(
     if kv_cache_max_requests < 1:
         raise RuntimeError("The amount of KV cache memory is insufficient to "
                            "run this model. Please try with more GPUs.")
-    if kv_cache_memory / n_gpus < 10.0:
+    if cache_memory / n_gpus < 10.0:
         logger.warning(
             f"The KV cache memory per GPU is less than 10 GB. "
             "Performance may be undesirable. Please consider using a different "
@@ -126,6 +144,7 @@ def finetune_setting(
     input_len: int,
     output_len: int,
     pp_size: int,
+    disable_optimistic_tuning: bool = False,
 ) -> Tuple[int, int]:
     """ Calculate and fine-tune the engine build settings (max batch size and
     max num tokens). Both max batch size and max num tokens are fine-tuned
@@ -137,6 +156,7 @@ def finetune_setting(
         input_len (int): Input sequence length to compile the engine.
         output_len (int): Output sequence length to compile the engine.
         pp_size (int): Number of pipeline parallel stages.
+        disable_optimistic_tuning (bool): Whether to disable optimistic tuning.

     Returns:
         Tuple[int, int]: Tuple containing fine-tuned values for engine
@@ -148,13 +168,16 @@ def finetune_setting(
     raw_token = min(raw_bs * (1 + input_len / output_len), 32768)

     # Fine-tune the max batch size.
-    # Set min BS to be 64.
-    if raw_bs < 256:
-        max_bs = max(64, 32 * math.ceil(raw_bs / 32))
-    elif raw_bs < 1024:
-        max_bs = 128 * math.ceil(raw_bs / 128)
+    if disable_optimistic_tuning:
+        max_bs = 2 * math.floor(raw_bs / 2)
     else:
-        max_bs = 256 * math.ceil(raw_bs / 256)
+        # Set min BS to be 64.
+        if raw_bs < 256:
+            max_bs = max(64, 32 * math.ceil(raw_bs / 32))
+        elif raw_bs < 1024:
+            max_bs = 128 * math.ceil(raw_bs / 128)
+        else:
+            max_bs = 256 * math.ceil(raw_bs / 256)

     # Fine-tune the max num tokens.
     # Set min to 2048 to ensure Ctx/Gen overlap efficiency
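The effect of disable_optimistic_tuning is easiest to see on a concrete value: with a hypothetical raw_bs of 100, the optimistic path rounds up while the conservative path rounds down, so a hybrid model is never handed a max batch size larger than the cache estimate supports.

import math

raw_bs = 100  # hypothetical estimate of requests that fit in cache memory

optimistic = max(64, 32 * math.ceil(raw_bs / 32))   # 128: rounds up, may oversubscribe
conservative = 2 * math.floor(raw_bs / 2)           # 100: never exceeds the estimate
print(optimistic, conservative)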