import contextlib
import json
import os
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Generic, List, Optional, TypeVar

import filelock
import torch
import transformers
from transformers.utils import HF_MODULES_CACHE

from tensorrt_llm._torch.pyexecutor.config_utils import (
    is_nemotron_hybrid, load_pretrained_config)
from tensorrt_llm._utils import get_sm_version, torch_dtype_to_binding
from tensorrt_llm.bindings import LayerType as LayerTypeCpp
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.llmapi.llm_args import (DeepSeekSparseAttentionConfig,
                                          MoeLoadBalancerConfig)
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo

TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)


@contextlib.contextmanager
def config_file_lock(timeout: int = 10):
    """
    Context manager for file locking when loading pretrained configs.

    This prevents race conditions when multiple processes try to
    download/load the same model configuration simultaneously.

    Args:
        timeout: Maximum time to wait for lock acquisition in seconds
    """
    # Use a single global lock file in the HF cache directory.
    # This serializes all model loading operations to prevent race conditions.
    lock_path = Path(HF_MODULES_CACHE) / "_remote_code.lock"

    # Create and acquire the lock
    lock = filelock.FileLock(str(lock_path), timeout=timeout)
    try:
        with lock:
            yield
    except (PermissionError, filelock.Timeout):
        # Fall back to a lock file in the temp directory.
        tmp_dir = Path(tempfile.gettempdir())
        tmp_dir.mkdir(parents=True, exist_ok=True)
        tmp_lock_path = tmp_dir / "_remote_code.lock"
        tmp_lock = filelock.FileLock(str(tmp_lock_path), timeout=timeout)
        try:
            with tmp_lock:
                yield
        except filelock.Timeout:
            logger.warning(
                f"failed to acquire tempdir config lock within {timeout} seconds, proceeding without lock"
            )
            # Proceed without a lock.
            yield
        except PermissionError as e:
            logger.warning(
                f"tempdir config lock unavailable due to OS/permission issue: {e}, proceeding without lock"
            )
            # Proceed without a lock.
            yield


@dataclass(kw_only=True)
class ModelConfig(Generic[TConfig]):
    pretrained_config: Optional[TConfig] = None
    mapping: Mapping = field(default_factory=Mapping)

    # Quantization configs
    quant_config: QuantConfig = field(default_factory=QuantConfig)
    # Per-linear-layer quantization in quant_cfg.json or hf_quant_config.json
    quant_config_dict: Optional[Dict[str, QuantConfig]] = None
    # Delay weights creation to DecoderModelForCausalLM.__post_init__
    # to support mixed quantization.
    skip_create_weights_in_init: bool = False

    spec_config: Optional["DecodingBaseConfig"] = None
    lora_config: Optional["LoraConfig"] = None
    sparse_attention_config: Optional["SparseAttentionConfig"] = None

    is_generation: bool = True
    max_num_tokens: int = 8192
    max_seq_len: Optional[int] = None

    moe_max_num_tokens: Optional[int] = None
    moe_load_balancer: Optional[MoeLoadBalancerConfig] = None

    attn_backend: str = 'TRTLLM'
    moe_backend: str = 'CUTLASS'  # options can be CUTLASS, TRTLLM
    # If true, disables FC2+finalize fusion in the CUTLASS MoE backend.
    moe_disable_finalize_fusion: bool = False
    # If true, use low-precision combine in MoE operations (only for NVFP4 quantization).
    use_low_precision_moe_combine: bool = False

    allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO

    # If true, enable min-latency mode. Currently only used for Llama4.
    enable_min_latency: bool = False
    # Allow models to select ops according to whether CUDA Graphs are used.
    use_cuda_graph: bool = False

    force_dynamic_quantization: bool = False

    # If true, use torch.compile for embedding layers.
    enable_torch_compile_for_embedding: bool = False

    extra_attrs: Dict = field(default_factory=dict, repr=False, init=False)

    _frozen: bool = field(default=False, init=False, repr=False)

    # If true, ONLY the vision encoder part of the full model is loaded/executed.
    mm_encoder_only: bool = False

    def __setattr__(self, key, value):
        """
        Prevent modification of frozen instance attributes.

        However, we allow modification of 'extra_attrs' attributes for
        torch.compile and 'pretrained_config' attributes for multimodal
        models. 'quant_config' is allowed to be modified to set a different
        quantization for VLMs. All other attributes are frozen.

        This can be bypassed by manually setting '_frozen' to False. The
        design is to discourage modifying the attributes unintentionally.
        """
        if self._frozen:
            if key not in ('_frozen', 'extra_attrs', 'pretrained_config',
                           'quant_config'):
                raise AttributeError(
                    f"Cannot modify ModelConfig.'{key}' - instance is frozen")
        super().__setattr__(key, value)

    def __post_init__(self):
        if self.pretrained_config and hasattr(self.pretrained_config,
                                              "architectures"):
            self.is_generation = self.is_generation_model(
                self.pretrained_config.architectures,
                mm_encoder_only=self.mm_encoder_only)

        def get_all_reduce_strategy(strategy: str = "AUTO"):
            maps = {
                "AUTO": AllReduceStrategy.AUTO,
                "NCCL": AllReduceStrategy.NCCL,
                "UB": AllReduceStrategy.UB,
                "MINLATENCY": AllReduceStrategy.MIN_LATENCY,
                "ONESHOT": AllReduceStrategy.ONESHOT,
                "TWOSHOT": AllReduceStrategy.TWOSHOT,
                "LOWPRECISION": AllReduceStrategy.LOWPRECISION,
                "MNNVL": AllReduceStrategy.MNNVL,
                "NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC
            }
            key = strategy.upper()
            return maps[key] if key in maps else AllReduceStrategy.AUTO

        if isinstance(self.allreduce_strategy, str):
            self.allreduce_strategy = get_all_reduce_strategy(
                self.allreduce_strategy)

    @property
    def torch_dtype(self) -> torch.dtype:
        """Get the torch dtype of the model."""
        # TODO: this is an assumption that an HF model is always in bfloat16.
        # We should figure out a better way to handle this if other models
        # start to not report dtype.
        return self.pretrained_config.torch_dtype or torch.bfloat16

    @property
    def fuse_pos_embd(self):
        if self.attn_backend == 'TRTLLM':
            return True
        elif self.attn_backend == 'FLASHINFER':
            return False
        return False

    @property
    def enable_flash_mla(self):
        if self.attn_backend == 'TRTLLM':
            if getattr(self.pretrained_config, "kv_lora_rank",
                       None) and getattr(self.pretrained_config,
                                         "qk_rope_head_dim", None):
                head_dim = self.pretrained_config.kv_lora_rank + self.pretrained_config.qk_rope_head_dim
                if head_dim == 576 and torch.cuda.get_device_capability() == (
                        9, 0):
                    return True
        return False

    def get_quant_config(self, name: Optional[str] = None) -> QuantConfig:
        if name is None or self.per_layer_quant_configs is None:
            return self.quant_config

        if name in self.per_layer_quant_configs:
            return self.per_layer_quant_configs[name]

        raise ValueError(f'quant config of {name} is not found')

    @staticmethod
    def is_generation_model(model_architectures: Optional[List[str]],
                            mm_encoder_only: bool = False) -> bool:
        if model_architectures is None:
            logger.warning(
                "Model architectures is None, default to is_generation_model=True"
            )
            return True
        if mm_encoder_only:
            return False
        return model_architectures[0] not in [
            "BertForSequenceClassification", "Qwen2ForProcessRewardModel",
            "Qwen2ForRewardModel", "LlamaForTextEmbedding"
        ]
        # TODO: should be 'not model_type == ModelType.ENCODER_ONLY'
        # once ModelType is used in pytorch flow.

    @staticmethod
    def load_modelopt_quant_config(quant_config_file, checkpoint_dir,
                                   moe_backend):
        quant_config = QuantConfig()
        layer_quant_config = None

        with open(quant_config_file) as f:
            quant_config_dict = json.load(f)

        json_quant_configs = quant_config_dict['quantization']

        quant_config.quant_algo = json_quant_configs.get('quant_algo', None)
        # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES
        if quant_config.quant_algo == "fp8_pb_wo":
            quant_config.quant_algo = 'FP8_BLOCK_SCALES'
        quant_config.kv_cache_quant_algo = json_quant_configs.get(
            'kv_cache_quant_algo', None)
        quant_config.group_size = json_quant_configs.get('group_size', None)
        quant_config.exclude_modules = json_quant_configs.get(
            'exclude_modules', None)

        if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
            json_extended_quant_configs: dict = {}
            # See tests/unittest/llmapi/test_llm_quant.py
            try:
                mixed_quant_config_file = transformers.utils.hub.cached_file(
                    checkpoint_dir, 'quant_cfg.json')
                with open(mixed_quant_config_file) as fm:
                    json_extended_quant_configs = json.load(fm)
            except Exception:
                logger.info(
                    "No quant_cfg.json found for layer quant info, using hf_quant_config.json."
                )
            json_quant_configs.update(json_extended_quant_configs)
            # kv_cache_quant_algo is global regardless of MIXED_PRECISION
            kv_cache_quant_algo = json_quant_configs.get(
                'kv_cache_quant_algo', None)
            mixed_quant_configs = json_quant_configs.get(
                'quantized_layers', None)
            if (kv_quant_lhs := json_extended_quant_configs.get(
                    "kv_cache_quant_algo", None)) is not None and (
                        kv_quant_rhs :=
                        quant_config.kv_cache_quant_algo) is not None:
                if kv_quant_lhs != kv_quant_rhs:
                    raise RuntimeError(
                        f"The kvcache config in 'quant_cfg.json', {kv_quant_lhs}, "
                        f"is different from 'hf_quant_config.json', {kv_quant_rhs}!"
                    )
            quant_config.kv_cache_quant_algo = json_quant_configs[
                "kv_cache_quant_algo"]

            for layer in mixed_quant_configs:
                config = QuantConfig()
                config.kv_cache_quant_algo = kv_cache_quant_algo
                config.quant_algo = mixed_quant_configs[layer]['quant_algo']
                config.group_size = mixed_quant_configs[layer].get(
                    'group_size', None)
                mixed_quant_configs[layer] = config
            layer_quant_config = mixed_quant_configs
        elif quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
            if quant_config.group_size is None:
                quant_config.group_size = 128

        if moe_backend == 'TRTLLM' and quant_config.quant_algo == "FP8_BLOCK_SCALES" and quant_config.exclude_modules is None:
            quant_config.exclude_modules = [
                "*kv_b_proj*", "*k_b_proj*", "*eh_proj"
            ]
        return quant_config, layer_quant_config

    @staticmethod
    def get_mxfp4_quant_algo(moe_backend, is_dynamic_quant=False):
        quant_algo = ModelConfig.override_quant_algo()
        if quant_algo is None and not is_dynamic_quant:
            if get_sm_version() >= 100:
                if moe_backend == 'TRITON':
                    return QuantAlgo.W4A8_MXFP4_FP8
                else:
                    return QuantAlgo.W4A8_MXFP4_MXFP8
            else:
                return QuantAlgo.W4A16_MXFP4
        else:
            return quant_algo

    @staticmethod
    def load_hf_quant_config(hf_quant_config, moe_backend):
        quant_config = QuantConfig()
        layer_quant_config = None

        # DeepSeek V3 FP8 ckpt
        if hf_quant_config.get(
                "quant_method") == "fp8" and hf_quant_config.get(
                    "weight_block_size", []):
            quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
            if moe_backend == 'TRTLLM':
                # TODO: This is a hack. Remove after fp8 bmm is integrated.
                quant_config.exclude_modules = [
                    "*kv_b_proj*", "*k_b_proj*", "*eh_proj"
                ]
            else:
                quant_config.exclude_modules = ["*eh_proj"]

            block_size = hf_quant_config.get("weight_block_size", [])
            assert tuple(block_size) == (
                128,
                128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
            quant_config.group_size = block_size[0]
        # MXFP4 checkpoints.
        elif hf_quant_config.get("quant_method") == "mxfp4":
            quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo(
                moe_backend)
            quant_config.group_size = 32
            quant_config.exclude_modules = [
                'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv',
                'embedding', 'unembedding'
            ]

        return quant_config, layer_quant_config

    @staticmethod
    def load_quant_config_from_dtypes_json(dtypes_json_file, moe_backend: str):
        quant_config = QuantConfig()
        layer_quant_config = None

        exclude_modules = set()
        has_mxfp4 = False
        is_dynamic_quant = False
        with open(dtypes_json_file) as f:
            dtypes_json = json.load(f)
            for layer, dtype in dtypes_json.items():
                if layer.endswith("weight"):
                    if dtype == "BF16" or dtype == "FP16":
                        names = layer.split(".")
                        exclude_modules.add('.'.join(names[:-1]))
                    elif dtype == "MXFP4":
                        # This is the path for the fp8 checkpoint which requires dynamic quantization.
                        is_dynamic_quant = True
                        has_mxfp4 = True
                elif layer.endswith("weight.blocks"):
                    scale_name = layer.replace("weight.blocks",
                                               "weight.scales")
                    scale_dtype = dtypes_json.get(scale_name, None)
                    assert scale_dtype == "UE8"
                    is_dynamic_quant = False
                    has_mxfp4 = True

        if has_mxfp4:
            quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo(
                moe_backend, is_dynamic_quant)
            quant_config.group_size = 32
            quant_config.exclude_modules = list(exclude_modules)
            logger.info(f"Setting quant_config: {quant_config}")
        return quant_config, layer_quant_config

    @staticmethod
    def override_quant_algo():
        new_algo = os.environ.get("OVERRIDE_QUANT_ALGO", None)
        supported_algos = {
            "W4A16_MXFP4": QuantAlgo.W4A16_MXFP4,
            "W4A8_MXFP4_MXFP8": QuantAlgo.W4A8_MXFP4_MXFP8,
            "W4A8_MXFP4_FP8": QuantAlgo.W4A8_MXFP4_FP8,
        }
        if new_algo is not None:
            if new_algo.upper() in supported_algos:
                return supported_algos[new_algo.upper()]
            else:
                logger.warning(
                    f"Unsupported quant algo: {new_algo}, supported algos: {supported_algos.keys()}"
                )
        return None

    @classmethod
    def from_pretrained(cls,
                        checkpoint_dir: str,
                        trust_remote_code=False,
                        **kwargs):
        # Use a file lock to prevent race conditions when multiple processes
        # try to import/cache the same remote model config file.
        with config_file_lock():
            # When model_format is TLLM_ENGINE, checkpoint_dir may be None;
            # guard against sending requests for a None path.
            if checkpoint_dir is not None:
                pretrained_config = load_pretrained_config(
                    checkpoint_dir,
                    trust_remote_code=trust_remote_code,
                    **kwargs,
                )
                if pretrained_config.architectures[
                        0] == "DeepseekV32ForCausalLM":
                    sparse_attention_config = kwargs.get(
                        'sparse_attention_config')
                    if sparse_attention_config:
                        index_n_heads = sparse_attention_config.index_n_heads or pretrained_config.index_n_heads
                        index_head_dim = sparse_attention_config.index_head_dim or pretrained_config.index_head_dim
                        index_topk = sparse_attention_config.index_topk or pretrained_config.index_topk
                        indexer_max_chunk_size = sparse_attention_config.indexer_max_chunk_size
                    else:
                        index_n_heads = pretrained_config.index_n_heads
                        index_head_dim = pretrained_config.index_head_dim
                        index_topk = pretrained_config.index_topk
                        indexer_max_chunk_size = None
                    kwargs[
                        'sparse_attention_config'] = DeepSeekSparseAttentionConfig(
                            index_n_heads=index_n_heads,
                            index_head_dim=index_head_dim,
                            index_topk=index_topk,
                            indexer_max_chunk_size=indexer_max_chunk_size)
            else:
                raise ValueError(
                    "checkpoint_dir is None. Cannot load model config without a valid checkpoint directory."
                )

        # Get a cached file from a path or repo id; return None if it does not exist.
        def cached_file(path_or_repo_id, file_name):
            try:
                return transformers.utils.hub.cached_file(
                    path_or_repo_id, file_name)
            except OSError:
                return None

        # Some checkpoints lack torch_dtype; populate it from dtype.
        pretrained_config.torch_dtype = getattr(pretrained_config, 'dtype',
                                                None)

        quant_config = QuantConfig()
        layer_quant_config = None
        moe_backend = kwargs.get('moe_backend', 'CUTLASS')

        # quantized ckpt in modelopt format
        if quant_config_file := cached_file(checkpoint_dir,
                                            'hf_quant_config.json'):
            quant_config, layer_quant_config = cls.load_modelopt_quant_config(
                quant_config_file, checkpoint_dir, moe_backend)
        # quantized ckpt in other formats
        elif hasattr(pretrained_config, "quantization_config"):
            hf_quant_config = pretrained_config.quantization_config
            quant_config, layer_quant_config = cls.load_hf_quant_config(
                hf_quant_config, moe_backend)
        elif quant_config_file := cached_file(checkpoint_dir, 'dtypes.json'):
            quant_config, layer_quant_config = cls.load_quant_config_from_dtypes_json(
                quant_config_file, moe_backend)

        model_config = cls(pretrained_config=pretrained_config,
                           quant_config=quant_config,
                           quant_config_dict=layer_quant_config,
                           **kwargs)
        model_config._frozen = True
        return model_config

    def get_bindings_model_config(self,
                                  tokens_per_block: Optional[int] = None
                                  ) -> "ModelConfigCpp":
        """
        Construct the bindings config for the model.

        Currently it adheres to gptJsonConfig.cpp::createModelConfig, which
        assumes that an engine has been created.

        Args:
            tokens_per_block: The number of tokens per block. Note that in the
                PyTorch flow tokens_per_block is not available in the model
                config; instead it is defined in the executor config.

        Returns:
            The bindings model config.
        """
        # TODO smor- this isn't robust, and currently tested for LlamaConfig only
        # TODO smor- currently assuming no rnn layers, no MoE
        from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp

        num_heads = self.pretrained_config.num_attention_heads // (
            self.mapping.tp_size * self.mapping.cp_size)

        hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size

        model_config_cpp = ModelConfigCpp(
            vocab_size=self.pretrained_config.vocab_size,
            num_layers=self.pretrained_config.num_hidden_layers,
            num_attention_layers=self.get_num_attention_layers(),
            num_rnn_layers=0,
            num_heads=num_heads,
            hidden_size=hidden_size,
            data_type=torch_dtype_to_binding(
                self.pretrained_config.torch_dtype))

        # For kv cache size calculation: set tokens_per_block
        if tokens_per_block is None:
            logger.warning(
                f"tokens_per_block is not set, using default value {model_config_cpp.tokens_per_block}"
            )
        else:
            model_config_cpp.tokens_per_block = tokens_per_block

        num_key_value_heads = getattr(self.pretrained_config,
                                      "num_key_value_heads", num_heads)
        if isinstance(num_key_value_heads, (list, tuple)):
            # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
            num_kv_heads_per_layer = [
                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
                for kv_heads in num_key_value_heads
            ]
            model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
        else:
            num_kv_heads = num_key_value_heads // (self.mapping.tp_size *
                                                   self.mapping.cp_size)
            model_config_cpp.set_num_kv_heads(num_kv_heads)

        mlp_hidden_size = None
        if self.pretrained_config.intermediate_size is not None:
            mlp_hidden_size = self.pretrained_config.intermediate_size // self.mapping.tp_size
        else:
            # TODO: once tensorrt_llm._torch.AutoConfig is implemented, the following logic
            # should be moved to tensorrt_llm._torch.AutoConfig of the relevant modeling_xxx file
            if hasattr(self.pretrained_config, "architectures"
                       ) and self.pretrained_config.architectures is not None:
                architectures = self.pretrained_config.architectures
                if len(architectures
                       ) == 1 and architectures[0] == "DeciLMForCausalLM":
                    mlp_hidden_size = self._infer_nemotron_ffn_mult(
                    ) // self.mapping.tp_size
                else:
                    raise ValueError(
                        f"Inferring mlp hidden size for model architecture: {architectures} isn't supported yet"
                    )
            if mlp_hidden_size is None:
                raise ValueError(
                    f"Failed to infer mlp hidden size for model: {self.pretrained_config.model_type}"
                )

        # For kv cache size calculation: set size_per_head
        head_dim_names = ["head_size", "head_dim"]
        head_size = None
        for head_dim_name in head_dim_names:
            if hasattr(self.pretrained_config, head_dim_name):
                value = getattr(self.pretrained_config, head_dim_name)
                if value is not None:
                    head_size = value
                    break

        if head_size is None:
            assert hidden_size % num_heads == 0, (
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
            calculated_head_size = hidden_size // num_heads
            logger.warning(
                f"head_size/head_dim is not set or None, using default value {calculated_head_size}"
            )
            head_size = calculated_head_size

        model_config_cpp.mlp_hidden_size = mlp_hidden_size
        model_config_cpp.size_per_head = head_size

        # NOTE: this method is not robust and is only tested for Gemma3ForCausalLM.
        layer_types = self.get_layer_types()
        if layer_types is not None:
            model_config_cpp.layer_types = layer_types

        return model_config_cpp

    def _infer_nemotron_ffn_mult(self):
        # TODO smor: this is a hack to support Nemotron-Super-49B-v1 with LoRA, tracked by TRTLLM-5045 ticket
        # Nemotron-NAS has a variable ffn_mult for each layer; we need to find the maximum
        # so that we don't set a too-small mlp_hidden_size. This solution leads to a memory
        # consumption that is higher than required.
        biggest_ffn_mult = max([
            (x.ffn.ffn_mult if x.ffn.ffn_mult is not None else 0)
            for x in self.pretrained_config.block_configs
        ])

        from tensorrt_llm._torch.models.modeling_nemotron_nas import \
            _ffn_mult_to_intermediate_size
        mlp_hidden_size = _ffn_mult_to_intermediate_size(
            biggest_ffn_mult, self.pretrained_config.hidden_size)

        return mlp_hidden_size

    def get_layer_types(self) -> Optional[List[LayerTypeCpp]]:
        """
        This method is a hack to support the effort to switch to KvCacheManagerCpp.
        Currently, it is only tested for Gemma3ForCausalLM.
        For other models, it will return None.
        """
        if self.pretrained_config.architectures[0] in ["Gemma3ForCausalLM"]:
            logger.debug(
                f"Setting layer types for {self.pretrained_config.architectures}"
            )
            return [
                LayerTypeCpp.ATTENTION,
            ] * self.pretrained_config.num_hidden_layers
        else:
            return None

    def get_num_attention_layers(self):
        if is_nemotron_hybrid(self.pretrained_config):
            return self.pretrained_config.hybrid_override_pattern.count("*")
        elif hasattr(
                self.pretrained_config, "architectures"
        ) and self.pretrained_config.architectures is not None and self.pretrained_config.architectures[
                0] in ["Qwen3NextForCausalLM"]:
            # Qwen3NextForCausalLM has a hybrid attention pattern (1:3 full attention : linear attention),
            # so we need to count only the full-attention layers.
            return self.pretrained_config.num_hidden_layers // self.pretrained_config.full_attention_interval
        else:
            return self.pretrained_config.num_hidden_layers