TensorRT-LLM/tensorrt_llm/bench/build/dataclasses.py
from transformers import AutoConfig
from typing import Optional, Literal
from pydantic import AliasPath, BaseModel, Field, AliasChoices, model_validator
import huggingface_hub
from huggingface_hub.constants import (
SAFETENSORS_INDEX_FILE,
SAFETENSORS_MAX_HEADER_LENGTH,
SAFETENSORS_SINGLE_FILE,
)
from huggingface_hub.utils import SafetensorsRepoMetadata, SafetensorsFileMetadata, TensorInfo
from huggingface_hub.utils import tqdm as hf_tqdm
from tqdm.contrib.concurrent import thread_map
import os
import json
import struct
def parse_safetensors_file_metadata(model_path, filename):
with open(os.path.join(model_path, filename), "rb") as f:
metadata_size = f.read(8)
metadata_size = struct.unpack("<Q", metadata_size)[0]
if metadata_size > SAFETENSORS_MAX_HEADER_LENGTH:
raise RuntimeError(
f"Failed to parse safetensors header for '{filename}' (model_path '{model_path}'): "
f"safetensors header is too big. Maximum supported size is "
f"{SAFETENSORS_MAX_HEADER_LENGTH} bytes (got {metadata_size}).")
metadata_as_bytes = f.read(metadata_size)
try:
metadata_as_dict = json.loads(metadata_as_bytes.decode(errors="ignore"))
except json.JSONDecodeError as e:
raise RuntimeError(
f"Failed to parse safetensors header for '{filename}' (model_path '{model_path}'): "
"header format not recognized. Please make sure this is a correctly formatted safetensors file."
) from e
try:
return SafetensorsFileMetadata(
metadata=metadata_as_dict.get("__metadata__", {}),
tensors={
key:
TensorInfo(
dtype=tensor["dtype"],
shape=tensor["shape"],
data_offsets=tuple(tensor["data_offsets"]), # type: ignore
)
for key, tensor in metadata_as_dict.items()
if key != "__metadata__"
},
)
except (KeyError, IndexError) as e:
raise RuntimeError(
f"Failed to parse safetensors header for '{filename}' (model_path '{model_path}'): "
"header format not recognized. Please make sure this is a correctly formatted safetensors file."
) from e
def get_safetensors_metadata(model_name_or_path):
""" Read the safetensors metadata from HF model. """
if os.path.isdir(model_name_or_path):
if os.path.exists(
os.path.join(model_name_or_path, SAFETENSORS_SINGLE_FILE)):
file_metadata = parse_safetensors_file_metadata(
model_path=model_name_or_path, filename=SAFETENSORS_SINGLE_FILE)
return SafetensorsRepoMetadata(
metadata=None,
sharded=False,
weight_map={
tensor_name: SAFETENSORS_SINGLE_FILE
for tensor_name in file_metadata.tensors.keys()
},
files_metadata={SAFETENSORS_SINGLE_FILE: file_metadata},
)
elif os.path.exists(
os.path.join(model_name_or_path, SAFETENSORS_INDEX_FILE)):
with open(os.path.join(model_name_or_path,
SAFETENSORS_INDEX_FILE)) as f:
index = json.load(f)
weight_map = index.get("weight_map", {})
# Fetch metadata per shard
files_metadata = {}
def _parse(filename: str) -> None:
files_metadata[filename] = parse_safetensors_file_metadata(
model_path=model_name_or_path, filename=filename)
thread_map(
_parse,
set(weight_map.values()),
desc="Parse safetensors files",
tqdm_class=hf_tqdm,
)
return SafetensorsRepoMetadata(
metadata=index.get("metadata", None),
sharded=True,
weight_map=weight_map,
files_metadata=files_metadata,
)
else:
# Not a safetensors repo
raise RuntimeError(
f"'{model_name_or_path}' is not a safetensors repo. Couldn't find '{SAFETENSORS_INDEX_FILE}' or '{SAFETENSORS_SINGLE_FILE}' files."
)
else:
return huggingface_hub.get_safetensors_metadata(model_name_or_path)
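
# Illustrative sketch (not part of the original module): both branches above return a
# SafetensorsRepoMetadata whose parameter_count maps dtype names to element counts,
# which is how ModelConfig.get_param_count sums parameters below. The helper name is
# hypothetical.
def _example_count_params(model_name_or_path):
    metadata = get_safetensors_metadata(model_name_or_path)
    return sum(metadata.parameter_count.values())
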
class ModelConfig(BaseModel):
""" Model specific configurations. The parameters are needed in engine
setting calculation.
"""
name: str
model_type: str
param_count: int
num_hidden_layers: int = Field(validation_alias=AliasChoices(
"num_hidden_layers",
"n_layer",
AliasPath("text_config", "num_hidden_layers"),
AliasPath("language_config", "num_hidden_layers"),
))
num_attention_layers: Optional[int] = Field(default=None)
num_attention_heads: int = Field(validation_alias=AliasChoices(
"num_attention_heads",
"n_head",
AliasPath("text_config", "num_attention_heads"),
AliasPath("language_config", "num_attention_heads"),
))
num_key_value_heads: Optional[int] = Field(
default=None,
validation_alias=AliasChoices(
"num_key_value_heads",
"num_kv_heads",
AliasPath("text_config", "num_key_value_heads"),
AliasPath("language_config", "num_key_value_heads"),
),
)
hidden_size: int = Field(validation_alias=AliasChoices(
"hidden_size",
"n_embd",
AliasPath("text_config", "hidden_size"),
))
head_size: Optional[int] = Field(default=None,
validation_alias=AliasChoices(
"head_size",
"head_dim",
"attention_head_dim",
AliasPath("text_config", "head_dim"),
))
max_position_embeddings: Optional[int] = Field(
default=None,
validation_alias=AliasChoices(
"max_position_embeddings",
"n_positions",
AliasPath("text_config", "max_position_embeddings"),
))
dtype: Literal["float16", "bfloat16", "float32",
None] = Field(default="float16",
validation_alias=AliasChoices(
"dtype", "torch_dtype"))
@model_validator(mode="after")
def set_values_if_none(self):
""" Set the values if cannot get values from HF config.json. """
if not self.dtype: # for GPT-J
self.dtype = "float16"
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.head_size is None:
self.head_size = self.hidden_size // self.num_attention_heads
if self.num_attention_layers is None:
self.num_attention_layers = self.num_hidden_layers
return self
@classmethod
def get_param_count(cls, model_hf_name, hf_model_path):
""" Read the parameter count from HF safetensor metadata. """
if model_hf_name == "EleutherAI/gpt-j-6b": # GPT-J repo doesn't use safetensor format.
param_count = 6053381344
else:
model_name_or_path = hf_model_path or model_hf_name
metadata = get_safetensors_metadata(model_name_or_path)
param_count = sum(metadata.parameter_count.values())
assert param_count, f"Can't get valid parameter count for model: {model_name_or_path}."
return param_count
@classmethod
def from_hf(cls, model_hf_name, hf_model_path):
model_name_or_path = hf_model_path or model_hf_name
hf_config = AutoConfig.from_pretrained(
model_name_or_path, trust_remote_code=True).to_dict()
param_count = cls.get_param_count(model_hf_name, hf_model_path)
return cls(name=model_hf_name, param_count=param_count, **hf_config)
    def extra_model_cache_in_gb(self, bytes_per_elem, target_seq_len=None):
        """ Extra per-sequence cache size (beyond the KV cache) in GiB. Plain
        attention models have none. """
        return 0

    def cache_memory_fraction(self, cache_memory_fraction):
        """ Adjust the requested cache memory fraction; returned unchanged by default. """
        return cache_memory_fraction
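
# Illustrative sketch (not part of the original module): the AliasChoices/AliasPath
# validation aliases above let ModelConfig ingest a raw HF config dict directly. The
# model name and numbers below are made up for illustration.
def _example_model_config_from_dict():
    hf_config = {
        "model_type": "llama",
        "num_hidden_layers": 32,
        "num_attention_heads": 32,
        "num_key_value_heads": 8,
        "hidden_size": 4096,
        "max_position_embeddings": 8192,
        "torch_dtype": "bfloat16",
    }
    cfg = ModelConfig(name="example/llama-like-7b",
                      param_count=7_000_000_000,
                      **hf_config)
    # set_values_if_none fills head_size = 4096 // 32 = 128 and copies
    # num_hidden_layers into num_attention_layers since neither was provided.
    return cfg
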
class NemotronHybridConfig(ModelConfig):
hybrid_override_pattern: str
d_state: int = Field(validation_alias=AliasChoices(
"d_state",
"mamba_d_state",
"ssm_state_size",
))
d_conv: int = Field(validation_alias=AliasChoices(
"d_conv",
"mamba_d_conv",
"conv_kernel",
))
expand: int = Field(validation_alias=AliasChoices(
"expand",
"mamba_expand",
))
n_groups: int
mamba_head_dim: int
d_inner: Optional[int] = Field(default=None)
mamba_num_heads: Optional[int] = Field(default=None)
num_mamba_layers: Optional[int] = Field(default=None)
@model_validator(mode="after")
def set_values_if_none(self):
""" Set the values if cannot get values from HF config.json. """
if not self.d_inner:
self.d_inner = self.hidden_size * self.expand
if not self.mamba_num_heads:
self.mamba_num_heads = self.d_inner // self.mamba_head_dim
if self.num_mamba_layers is None:
self.num_mamba_layers = self.hybrid_override_pattern.count("M")
if self.num_attention_layers is None:
self.num_attention_layers = self.hybrid_override_pattern.count("*")
super().set_values_if_none()
return self
def extra_model_cache_in_gb(self, bytes_per_elem, target_seq_len=None):
conv_dim = self.d_inner + 2 * self.n_groups * self.d_state
conv_state_elems = conv_dim * (self.d_conv - 1)
ssm_state_elems = self.mamba_num_heads * self.mamba_head_dim * self.d_state
gb_per_mamba_cache = bytes_per_elem * self.num_mamba_layers * (
conv_state_elems + ssm_state_elems) / (1024**3)
return gb_per_mamba_cache
def cache_memory_fraction(self, cache_memory_fraction):
        # Each Mamba cache entry is fairly large (~50 MB for an 8B model), so be more
        # conservative when estimating the max batch size.
return cache_memory_fraction**2
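
# Illustrative sketch (not part of the original module): how extra_model_cache_in_gb
# above turns into a per-sequence Mamba cache estimate. All numbers are made up and
# only show the shape of the calculation (conv state + SSM state per Mamba layer).
def _example_nemotron_mamba_cache_gb():
    cfg = NemotronHybridConfig(
        name="example/nemotron-h-like",
        model_type="nemotron_h",
        param_count=8_000_000_000,
        num_hidden_layers=52,
        num_attention_heads=32,
        hidden_size=4096,
        hybrid_override_pattern="M" * 46 + "*" * 6,
        d_state=128,
        d_conv=4,
        expand=2,
        n_groups=8,
        mamba_head_dim=64,
    )
    # conv_dim = d_inner + 2 * n_groups * d_state, with d_inner = hidden_size * expand.
    # With 2-byte (bfloat16) states this comes out to roughly 0.1 GiB per sequence here.
    return cfg.extra_model_cache_in_gb(bytes_per_elem=2)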