import contextlib
import json
import os
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Generic, List, Optional, TypeVar

import filelock
import torch
import transformers
from transformers.utils import HF_MODULES_CACHE

from tensorrt_llm import logger
from tensorrt_llm._torch.pyexecutor.config_utils import (is_nemotron_hybrid,
                                                         load_pretrained_config)
from tensorrt_llm._utils import get_sm_version, torch_dtype_to_binding
from tensorrt_llm.bindings import LayerType as LayerTypeCpp
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.llmapi.llm_args import (DeepSeekSparseAttentionConfig,
                                          MoeLoadBalancerConfig)
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo

TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)


@contextlib.contextmanager
def config_file_lock(timeout: int = 10):
    """
    Context manager for file locking when loading pretrained configs.

    This prevents race conditions when multiple processes try to download/load
    the same model configuration simultaneously.

    Args:
        timeout: Maximum time to wait for lock acquisition in seconds.
    """
    # Use a single global lock file in the HF cache directory.
    # This serializes all model loading operations to prevent race conditions.
    lock_path = Path(HF_MODULES_CACHE) / "_remote_code.lock"

    # Create and acquire the lock
    lock = filelock.FileLock(str(lock_path), timeout=timeout)

    try:
        with lock:
            yield
    except (PermissionError, filelock.Timeout):
        # Fallback to tempdir
        tmp_dir = Path(tempfile.gettempdir())
        tmp_dir.mkdir(parents=True, exist_ok=True)
        tmp_lock_path = tmp_dir / "_remote_code.lock"
        tmp_lock = filelock.FileLock(str(tmp_lock_path), timeout=timeout)
        try:
            with tmp_lock:
                yield
        except filelock.Timeout:
            logger.warning(
                f"failed to acquire tempdir config lock within {timeout} seconds, proceeding without lock"
            )
            # proceed without lock
            yield
        except PermissionError as e:
            logger.warning(
                f"tempdir config lock unavailable due to OS/permission issue: {e}, proceeding without lock"
            )
            # proceed without lock
            yield
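
# Illustrative use of config_file_lock (a sketch, not executed by this module;
# the checkpoint id below is a placeholder):
#
#     with config_file_lock(timeout=10):
#         cfg = transformers.AutoConfig.from_pretrained("org/model-name")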


@dataclass(kw_only=True)
class ModelConfig(Generic[TConfig]):
    pretrained_config: Optional[TConfig] = None
    mapping: Mapping = field(default_factory=Mapping)

    # Quantization configs
    quant_config: QuantConfig = field(default_factory=QuantConfig)
    # Per-linear-layer quantization from quant_cfg.json or hf_quant_config.json
    quant_config_dict: Optional[Dict[str, QuantConfig]] = None
    # Delay weights creation to DecoderModelForCausalLM.__post_init__
    # to support mixed quantization.
    skip_create_weights_in_init: bool = False

    spec_config: Optional["DecodingBaseConfig"] = None
    lora_config: Optional["LoraConfig"] = None
    sparse_attention_config: Optional["SparseAttentionConfig"] = None

    is_generation: bool = True
    max_num_tokens: int = 8192
    max_seq_len: Optional[int] = None

    moe_max_num_tokens: Optional[int] = None
    moe_load_balancer: Optional[MoeLoadBalancerConfig] = None

    attn_backend: str = 'TRTLLM'
    moe_backend: str = 'CUTLASS'  # options: CUTLASS, TRTLLM
    # If true, disables FC2+finalize fusion in the CUTLASS MoE backend.
    moe_disable_finalize_fusion: bool = False
    # If true, use low-precision combine in MoE operations (only for NVFP4 quantization).
    use_low_precision_moe_combine: bool = False

    # NVFP4 GEMM backend configuration: list of backends to consider for auto-selection.
    # The default excludes 'cutedsl' for faster build time; add 'cutedsl' for maximum performance.
    nvfp4_gemm_allowed_backends: List[str] = field(
        default_factory=lambda: ['cutlass', 'cublaslt', 'cuda_core'])

    allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO

    # If true, enable min-latency mode. Currently only used for Llama4.
    enable_min_latency: bool = False

    # Allow models to select ops according to whether CUDA graphs are used.
    use_cuda_graph: bool = False

    force_dynamic_quantization: bool = False

    # If true, use torch.compile for embedding layers.
    enable_torch_compile_for_embedding: bool = False

    extra_attrs: Dict = field(default_factory=dict, repr=False, init=False)

    _frozen: bool = field(default=False, init=False, repr=False)

    # If true, ONLY the vision encoder part of the full model is loaded/executed.
    mm_encoder_only: bool = False
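
    # ModelConfig is normally built via ModelConfig.from_pretrained() below; direct
    # construction is keyword-only. An illustrative sketch (hf_config is assumed to
    # be a transformers.PretrainedConfig loaded elsewhere):
    #
    #     config = ModelConfig(pretrained_config=hf_config, mapping=Mapping())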

    def __setattr__(self, key, value):
        """
        Prevent modification of frozen instance attributes.

        'extra_attrs' may still be modified for torch.compile, 'pretrained_config'
        for multimodal models, and 'quant_config' to set a different quantization
        for VLMs. All other attributes are frozen.

        This can be bypassed by manually setting '_frozen' to False; the design is
        meant to discourage modifying the attributes unintentionally.
        """
        if self._frozen:
            if key not in ('_frozen', 'extra_attrs', 'pretrained_config',
                           'quant_config'):
                raise AttributeError(
                    f"Cannot modify ModelConfig.'{key}' - instance is frozen")
        super().__setattr__(key, value)

    def __post_init__(self):
        if self.pretrained_config and hasattr(self.pretrained_config,
                                              "architectures"):
            self.is_generation = self.is_generation_model(
                self.pretrained_config.architectures,
                mm_encoder_only=self.mm_encoder_only)

        def get_all_reduce_strategy(strategy: str = "AUTO"):
            maps = {
                "AUTO": AllReduceStrategy.AUTO,
                "NCCL": AllReduceStrategy.NCCL,
                "UB": AllReduceStrategy.UB,
                "MINLATENCY": AllReduceStrategy.MIN_LATENCY,
                "ONESHOT": AllReduceStrategy.ONESHOT,
                "TWOSHOT": AllReduceStrategy.TWOSHOT,
                "LOWPRECISION": AllReduceStrategy.LOWPRECISION,
                "MNNVL": AllReduceStrategy.MNNVL,
                "NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC,
            }
            key = strategy.upper()
            return maps.get(key, AllReduceStrategy.AUTO)

        if isinstance(self.allreduce_strategy, str):
            self.allreduce_strategy = get_all_reduce_strategy(
                self.allreduce_strategy)

        # Set the default moe_max_num_tokens if not specified.
        # The maximum number of MoE tokens is multiplied by the DP size when attention DP is enabled.
        if self.moe_max_num_tokens is None:
            self.moe_max_num_tokens = self.max_num_tokens * self.mapping.dp_size
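            # Example: with the default max_num_tokens=8192 and attention DP over
            # 4 ranks (mapping.dp_size == 4), moe_max_num_tokens defaults to 32768.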

    @property
    def torch_dtype(self) -> torch.dtype:
        """Get the torch dtype of the model."""
        # TODO: this assumes an HF model is always bfloat16 when it does not report
        # a dtype. We should figure out a better way to handle this if other models
        # start to omit the dtype.
        return self.pretrained_config.torch_dtype or torch.bfloat16

    @property
    def fuse_pos_embd(self):
        if self.attn_backend == 'TRTLLM':
            return True
        elif self.attn_backend == 'FLASHINFER':
            return False
        return False
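
    # Flash MLA is only enabled for the TRTLLM attention backend on SM90 (Hopper)
    # when the MLA head dim (kv_lora_rank + qk_rope_head_dim) is 576; for
    # DeepSeek-style checkpoints that is 512 + 64.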
    @property
    def enable_flash_mla(self):
        if self.attn_backend == 'TRTLLM':
            if getattr(self.pretrained_config,
                       "kv_lora_rank", None) and getattr(
                           self.pretrained_config, "qk_rope_head_dim", None):
                head_dim = self.pretrained_config.kv_lora_rank + self.pretrained_config.qk_rope_head_dim
                if head_dim == 576 and torch.cuda.get_device_capability() == (
                        9, 0):
                    return True
        return False

    def get_quant_config(self, name: Optional[str] = None) -> QuantConfig:
        if name is None or self.per_layer_quant_configs is None:
            return self.quant_config

        if name in self.per_layer_quant_configs:
            return self.per_layer_quant_configs[name]

        raise ValueError(f'quant config of {name} is not found')

    @staticmethod
    def is_generation_model(model_architectures: Optional[List[str]],
                            mm_encoder_only: bool = False) -> bool:
        if model_architectures is None:
            logger.warning(
                "Model architectures is None, default to is_generation_model=True"
            )
            return True
        if mm_encoder_only:
            return False
        return model_architectures[0] not in [
            "BertForSequenceClassification", "Qwen2ForProcessRewardModel",
            "Qwen2ForRewardModel", "LlamaForTextEmbedding"
        ]
        # TODO: should be 'not model_type == ModelType.ENCODER_ONLY'
        # once ModelType is used in pytorch flow.
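
    # load_modelopt_quant_config() reads the 'quantization' section of a ModelOpt
    # hf_quant_config.json. An illustrative (assumed) shape of that file:
    #
    #     {
    #         "quantization": {
    #             "quant_algo": "FP8",
    #             "kv_cache_quant_algo": "FP8",
    #             "exclude_modules": ["lm_head"]
    #         }
    #     }
    #
    # For MIXED_PRECISION checkpoints, a sibling quant_cfg.json may additionally
    # provide per-layer entries under "quantized_layers".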
    @staticmethod
    def load_modelopt_quant_config(quant_config_file, checkpoint_dir,
                                   moe_backend):
        quant_config = QuantConfig()
        layer_quant_config = None

        with open(quant_config_file) as f:
            quant_config_dict = json.load(f)

        json_quant_configs = quant_config_dict['quantization']

        quant_config.quant_algo = json_quant_configs.get('quant_algo', None)
        # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES
        if quant_config.quant_algo == "fp8_pb_wo":
            quant_config.quant_algo = 'FP8_BLOCK_SCALES'
        quant_config.kv_cache_quant_algo = json_quant_configs.get(
            'kv_cache_quant_algo', None)
        quant_config.group_size = json_quant_configs.get('group_size', None)
        quant_config.exclude_modules = json_quant_configs.get(
            'exclude_modules', None)

        if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
            json_extended_quant_configs: dict = {}
            # See tests/unittest/llmapi/test_llm_quant.py
            try:
                mixed_quant_config_file = transformers.utils.hub.cached_file(
                    checkpoint_dir, 'quant_cfg.json')
                with open(mixed_quant_config_file) as fm:
                    json_extended_quant_configs = json.load(fm)
            except Exception:
                logger.info(
                    "No quant_cfg.json found for layer quant info, using hf_quant_config.json."
                )
            json_quant_configs.update(json_extended_quant_configs)
            # kv_cache_quant_algo is global regardless of MIXED_PRECISION
            kv_cache_quant_algo = json_quant_configs.get(
                'kv_cache_quant_algo', None)
            mixed_quant_configs = json_quant_configs.get(
                'quantized_layers', None)
            if (kv_quant_lhs := json_extended_quant_configs.get(
                    "kv_cache_quant_algo", None)) is not None and (
                        kv_quant_rhs :=
                        quant_config.kv_cache_quant_algo) is not None:
                if kv_quant_lhs != kv_quant_rhs:
                    raise RuntimeError(
                        f"The kvcache config in 'quant_cfg.json', {kv_quant_lhs}, "
                        f"is different from 'hf_quant_config.json', {kv_quant_rhs}!"
                    )
            quant_config.kv_cache_quant_algo = json_quant_configs[
                "kv_cache_quant_algo"]
            for layer in mixed_quant_configs:
                config = QuantConfig()
                config.kv_cache_quant_algo = kv_cache_quant_algo
                config.quant_algo = mixed_quant_configs[layer]['quant_algo']
                config.group_size = mixed_quant_configs[layer].get(
                    'group_size', None)
                mixed_quant_configs[layer] = config
            layer_quant_config = mixed_quant_configs
        elif quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
            if quant_config.group_size is None:
                quant_config.group_size = 128

        if moe_backend == 'TRTLLM' and quant_config.quant_algo == "FP8_BLOCK_SCALES" and quant_config.exclude_modules is None:
            quant_config.exclude_modules = [
                "*kv_b_proj*", "*k_b_proj*", "*eh_proj"
            ]
        return quant_config, layer_quant_config
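
    # MXFP4 quant algo selection: on SM100+ (Blackwell) activations are FP8 for the
    # TRITON MoE backend and MXFP8 otherwise (W4A8); older GPUs fall back to W4A16.
    # An OVERRIDE_QUANT_ALGO env override (see override_quant_algo) takes precedence.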
    @staticmethod
    def get_mxfp4_quant_algo(moe_backend, is_dynamic_quant=False):
        quant_algo = ModelConfig.override_quant_algo()
        if quant_algo is None and not is_dynamic_quant:
            if get_sm_version() >= 100:
                if moe_backend == 'TRITON':
                    return QuantAlgo.W4A8_MXFP4_FP8
                else:
                    return QuantAlgo.W4A8_MXFP4_MXFP8
            else:
                return QuantAlgo.W4A16_MXFP4
        else:
            return quant_algo
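
    # load_hf_quant_config() consumes an HF-style 'quantization_config' dict.
    # Illustrative (assumed) entries:
    #
    #     {"quant_method": "fp8", "weight_block_size": [128, 128]}     # DeepSeek-V3 style
    #     {"quant_method": "mxfp4", "modules_to_not_convert": ["lm_head"]}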
    @staticmethod
    def load_hf_quant_config(hf_quant_config, moe_backend):
        quant_config = QuantConfig()
        layer_quant_config = None

        # Read exclude_modules from HF config if present (HF format module names)
        hf_exclude_modules = hf_quant_config.get('modules_to_not_convert', None)

        # DeepSeek V3 FP8 ckpt
        if hf_quant_config.get("quant_method") == "fp8" and hf_quant_config.get(
                "weight_block_size", []):
            quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES

            block_size = hf_quant_config.get("weight_block_size", [])
            assert tuple(block_size) == (
                128, 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
            quant_config.group_size = block_size[0]

            # Set default exclude_modules for FP8_BLOCK_SCALES
            if moe_backend == 'TRTLLM':
                default_exclude = ["*kv_b_proj*", "*k_b_proj*", "*eh_proj"]
            else:
                default_exclude = ["*eh_proj"]

            # Merge HF config's modules_to_not_convert with default exclude_modules
            if hf_exclude_modules is not None:
                quant_config.exclude_modules = list(
                    set(hf_exclude_modules + default_exclude))
            else:
                quant_config.exclude_modules = default_exclude
        # MXFP4 checkpoints.
        elif hf_quant_config.get("quant_method") == "mxfp4":
            quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo(
                moe_backend)
            quant_config.group_size = 32

            # Default exclude_modules for MXFP4 (TRTLLM internal format)
            default_exclude = [
                'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv',
                'embedding', 'unembedding'
            ]

            # Merge HF config's modules_to_not_convert with default exclude_modules
            if hf_exclude_modules is not None:
                quant_config.exclude_modules = list(
                    set(hf_exclude_modules + default_exclude))
            else:
                quant_config.exclude_modules = default_exclude

        return quant_config, layer_quant_config

    @staticmethod
    def load_quant_config_from_dtypes_json(dtypes_json_file, moe_backend: str):
        quant_config = QuantConfig()
        layer_quant_config = None

        exclude_modules = set()
        has_mxfp4 = False
        is_dynamic_quant = False
        with open(dtypes_json_file) as f:
            dtypes_json = json.load(f)
            for layer, dtype in dtypes_json.items():
                if layer.endswith("weight"):
                    if dtype == "BF16" or dtype == "FP16":
                        names = layer.split(".")
                        exclude_modules.add('.'.join(names[:-1]))
                    elif dtype == "MXFP4":
                        # This is the path for the fp8 checkpoint which requires dynamic quantization.
                        is_dynamic_quant = True
                        has_mxfp4 = True
                elif layer.endswith("weight.blocks"):
                    scale_name = layer.replace("weight.blocks", "weight.scales")
                    scale_dtype = dtypes_json.get(scale_name, None)
                    assert scale_dtype == "UE8"
                    is_dynamic_quant = False
                    has_mxfp4 = True

        if has_mxfp4:
            quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo(
                moe_backend, is_dynamic_quant)
            quant_config.group_size = 32
            quant_config.exclude_modules = list(exclude_modules)
            logger.info(f"Setting quant_config: {quant_config}")

        return quant_config, layer_quant_config

    @staticmethod
    def override_quant_algo():
        new_algo = os.environ.get("OVERRIDE_QUANT_ALGO", None)
        supported_algos = {
            "W4A16_MXFP4": QuantAlgo.W4A16_MXFP4,
            "W4A8_MXFP4_MXFP8": QuantAlgo.W4A8_MXFP4_MXFP8,
            "W4A8_MXFP4_FP8": QuantAlgo.W4A8_MXFP4_FP8,
        }
        if new_algo is not None:
            if new_algo.upper() in supported_algos:
                return supported_algos[new_algo.upper()]
            else:
                logger.warning(
                    f"Unsupported quant algo: {new_algo}, supported algos: {supported_algos.keys()}"
                )
        return None
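
    # Typical entry point. An illustrative call (the checkpoint id is a placeholder):
    #
    #     model_config = ModelConfig.from_pretrained("org/model-name",
    #                                                trust_remote_code=True)
    #
    # The returned instance is frozen; see __setattr__ above.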
    @classmethod
    def from_pretrained(cls,
                        checkpoint_dir: str,
                        trust_remote_code=False,
                        **kwargs):
        # Use a file lock to prevent race conditions when multiple processes
        # try to import/cache the same remote model config file.
        with config_file_lock():
            # When model_format is TLLM_ENGINE, checkpoint_dir can be None;
            # guard against sending cyclic requests to a None URL.
            if checkpoint_dir is not None:
                pretrained_config = load_pretrained_config(
                    checkpoint_dir,
                    trust_remote_code=trust_remote_code,
                    **kwargs,
                )
                if pretrained_config.architectures[
                        0] == "DeepseekV32ForCausalLM":
                    sparse_attention_config = kwargs.get(
                        'sparse_attention_config')
                    if sparse_attention_config:
                        index_n_heads = sparse_attention_config.index_n_heads or pretrained_config.index_n_heads
                        index_head_dim = sparse_attention_config.index_head_dim or pretrained_config.index_head_dim
                        index_topk = sparse_attention_config.index_topk or pretrained_config.index_topk
                        indexer_max_chunk_size = sparse_attention_config.indexer_max_chunk_size
                        skip_indexer_for_short_seqs = sparse_attention_config.skip_indexer_for_short_seqs
                    else:
                        index_n_heads = pretrained_config.index_n_heads
                        index_head_dim = pretrained_config.index_head_dim
                        index_topk = pretrained_config.index_topk
                        indexer_max_chunk_size = None
                        skip_indexer_for_short_seqs = True
                    kwargs[
                        'sparse_attention_config'] = DeepSeekSparseAttentionConfig(
                            index_n_heads=index_n_heads,
                            index_head_dim=index_head_dim,
                            index_topk=index_topk,
                            indexer_max_chunk_size=indexer_max_chunk_size,
                            skip_indexer_for_short_seqs=
                            skip_indexer_for_short_seqs)
            else:
                raise ValueError(
                    "checkpoint_dir is None. Cannot load model config without a valid checkpoint directory."
                )

        # Get cached file from path or repo id; return None if it does not exist.
        def cached_file(path_or_repo_id, file_name):
            try:
                return transformers.utils.hub.cached_file(
                    path_or_repo_id, file_name)
            except OSError:
                return None

        # Some checkpoints lack torch_dtype; populate it from dtype.
        pretrained_config.torch_dtype = getattr(pretrained_config, 'dtype',
                                                None)
        quant_config = QuantConfig()
        layer_quant_config = None
        moe_backend = kwargs.get('moe_backend', 'CUTLASS')

        # quantized ckpt in modelopt format
        if quant_config_file := cached_file(checkpoint_dir,
                                            'hf_quant_config.json'):
            quant_config, layer_quant_config = cls.load_modelopt_quant_config(
                quant_config_file, checkpoint_dir, moe_backend)
        # quantized ckpt in other formats
        elif hasattr(pretrained_config, "quantization_config"):
            hf_quant_config = pretrained_config.quantization_config
            quant_config, layer_quant_config = cls.load_hf_quant_config(
                hf_quant_config, moe_backend)
        elif quant_config_file := cached_file(checkpoint_dir, 'dtypes.json'):
            quant_config, layer_quant_config = cls.load_quant_config_from_dtypes_json(
                quant_config_file, moe_backend)

        model_config = cls(pretrained_config=pretrained_config,
                           quant_config=quant_config,
                           quant_config_dict=layer_quant_config,
                           **kwargs)
        model_config._frozen = True
        return model_config

    def get_bindings_model_config(self,
                                  tokens_per_block: Optional[int] = None
                                  ) -> "ModelConfigCpp":
        """
        Construct the C++ bindings config for the model.

        Currently this adheres to gptJsonConfig.cpp::createModelConfig, which
        assumes that an engine has been created.

        Args:
            tokens_per_block: The number of tokens per block. Note that in the
                PyTorch flow, tokens_per_block is not available in the model
                config; it is defined in the executor config instead.

        Returns:
            The bindings model config.
        """
        # TODO smor- this isn't robust, and currently tested for LlamaConfig only
        # TODO smor- currently assuming no rnn layers, no MoE
        from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp

        # Attention DP should not shard attention heads; use attn_tp_size=1 in that case
        # so downstream KV calculations see the full (non-partitioned) head count.
        attn_tp_size = self.mapping.attn_tp_size if not self.mapping.enable_attention_dp else 1
        attn_cp_size = self.mapping.attn_cp_size

        num_heads = self.pretrained_config.num_attention_heads // (
            attn_tp_size * attn_cp_size)

        hidden_size = self.pretrained_config.hidden_size // attn_tp_size
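        # Example: 32 attention heads and hidden_size 4096 with attn_tp_size=4 and
        # attn_cp_size=1 yield num_heads=8 and hidden_size=1024 per rank.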

        model_config_cpp = ModelConfigCpp(
            vocab_size=self.pretrained_config.vocab_size,
            num_layers=self.pretrained_config.num_hidden_layers,
            num_attention_layers=self.get_num_attention_layers(),
            num_rnn_layers=0,
            num_heads=num_heads,
            hidden_size=hidden_size,
            data_type=torch_dtype_to_binding(
                self.pretrained_config.torch_dtype))

        # For kv cache size calculation: set tokens_per_block
        if tokens_per_block is None:
            logger.warning(
                f"tokens_per_block is not set, using default value {model_config_cpp.tokens_per_block}"
            )
        else:
            model_config_cpp.tokens_per_block = tokens_per_block

        num_key_value_heads = getattr(self.pretrained_config,
                                      "num_key_value_heads", num_heads)
        if isinstance(num_key_value_heads, (list, tuple)):
            # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
            num_kv_heads_per_layer = [
                kv_heads // (attn_tp_size * attn_cp_size)
                for kv_heads in num_key_value_heads
            ]
            model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
        else:
            num_kv_heads = num_key_value_heads // (attn_tp_size * attn_cp_size)
            model_config_cpp.set_num_kv_heads(num_kv_heads)

        mlp_hidden_size = None
        if self.pretrained_config.intermediate_size is not None:
            mlp_hidden_size = self.pretrained_config.intermediate_size // self.mapping.tp_size
        else:
            # TODO: once tensorrt_llm._torch.AutoConfig is implemented, the following logic
            # should be moved to tensorrt_llm._torch.AutoConfig of the relevant modeling_xxx file
            if hasattr(self.pretrained_config, "architectures"
                       ) and self.pretrained_config.architectures is not None:
                architectures = self.pretrained_config.architectures
                if len(architectures
                       ) == 1 and architectures[0] == "DeciLMForCausalLM":
                    mlp_hidden_size = self._infer_nemotron_ffn_mult(
                    ) // self.mapping.tp_size
                else:
                    raise ValueError(
                        f"Inferring mlp hidden size for model architecture: {architectures} isn't supported yet"
                    )
            if mlp_hidden_size is None:
                raise ValueError(
                    f"Failed to infer mlp hidden size for model: {self.pretrained_config.model_type}"
                )

        # For kv cache size calculation: set size_per_head
        head_dim_names = ["head_size", "head_dim"]
        head_size = None
        for head_dim_name in head_dim_names:
            if hasattr(self.pretrained_config, head_dim_name):
                value = getattr(self.pretrained_config, head_dim_name)
                if value is not None:
                    head_size = value
                    break

        if head_size is None:
            assert hidden_size % num_heads == 0, (
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
            calculated_head_size = hidden_size // num_heads
            logger.warning(
                f"head_size/head_dim is not set or None, using default value {calculated_head_size}"
            )
            head_size = calculated_head_size

        model_config_cpp.mlp_hidden_size = mlp_hidden_size
        model_config_cpp.size_per_head = head_size

        # NOTE: this method is not robust, for Gemma3ForCausalLM only
        layer_types = self.get_layer_types()
        if layer_types is not None:
            model_config_cpp.layer_types = layer_types

        return model_config_cpp

    def _infer_nemotron_ffn_mult(self):
        # TODO smor: this is a hack to support Nemotron-Super-49B-v1 with LoRA, tracked by the TRTLLM-5045 ticket.
        # Nemotron-NAS has a variable ffn_mult per layer; take the maximum so that
        # mlp_hidden_size is not set too small. This leads to higher memory
        # consumption than strictly required.
        biggest_ffn_mult = max([
            (x.ffn.ffn_mult if x.ffn.ffn_mult is not None else 0)
            for x in self.pretrained_config.block_configs
        ])

        from tensorrt_llm._torch.models.modeling_nemotron_nas import \
            _ffn_mult_to_intermediate_size
        mlp_hidden_size = _ffn_mult_to_intermediate_size(
            biggest_ffn_mult, self.pretrained_config.hidden_size)

        return mlp_hidden_size

    def get_layer_types(self) -> Optional[List[LayerTypeCpp]]:
        """
        This method is a hack to support the effort to switch to KvCacheManagerCpp.
        Currently, it is only tested for Gemma3ForCausalLM. For other models, it returns None.
        """
        if self.pretrained_config.architectures[0] in ["Gemma3ForCausalLM"]:
            logger.debug(
                f"Setting layer types for {self.pretrained_config.architectures}"
            )
            return [
                LayerTypeCpp.ATTENTION,
            ] * self.pretrained_config.num_hidden_layers
        else:
            return None

    def get_num_attention_layers(self):
        if is_nemotron_hybrid(self.pretrained_config):
            return self.pretrained_config.hybrid_override_pattern.count("*")
        elif hasattr(
                self.pretrained_config, "architectures"
        ) and self.pretrained_config.architectures is not None and self.pretrained_config.architectures[
                0] in ["Qwen3NextForCausalLM"]:
            # Qwen3NextForCausalLM has a hybrid attention pattern (1:3 full attention : linear attention),
            # so we need to compute the number of full-attention layers.
            return self.pretrained_config.num_hidden_layers // self.pretrained_config.full_attention_interval
        else:
            return self.pretrained_config.num_hidden_layers