diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py
index f20c52babf..bf5384af15 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/hf.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py
@@ -121,7 +121,7 @@ class AutoModelForCausalLMFactory(AutoModelFactory):
         self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs)
         self.model_kwargs = deep_merge_dicts(
             self._model_defaults,
-            self.model_kwargs,
+            self.model_kwargs or {},
         )
 
         # set sharding config source to huggingface
diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index 2fff7e47eb..19ed1f3510 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -4,7 +4,7 @@ import os
 import tempfile
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Dict, Generic, List, Optional, TypeVar
+from typing import Any, Dict, Generic, List, Optional, TypeVar
 
 import filelock
 import torch
@@ -470,6 +470,56 @@ class ModelConfig(Generic[TConfig]):
 
         # Some checkpoints lack torch_dtype, populate with dtype
         pretrained_config.torch_dtype = getattr(pretrained_config, 'dtype', None)
+
+        # Apply model_kwargs to override config parameters if provided
+        model_kwargs = kwargs.pop('model_kwargs', None)
+        if model_kwargs:
+
+            def _recursive_update_config(config: transformers.PretrainedConfig,
+                                         update_dict: Dict[str, Any]):
+                """
+                Recursively update a PretrainedConfig object with values from update_dict.
+                Args:
+                    config: PretrainedConfig object to update
+                    update_dict: Dictionary with values to update in the config
+                """
+                for key, value_new in update_dict.items():
+                    if not hasattr(config, key):
+                        logger.warning(
+                            f"model_kwargs key '{key}' not found in pretrained_config, ignoring."
+                        )
+                        continue
+
+                    target_value = getattr(config, key)
+
+                    # Handle nested PretrainedConfig objects when value is a dict
+                    if isinstance(value_new, dict) and isinstance(
+                            target_value, transformers.PretrainedConfig):
+                        # Recursively update the nested config
+                        logger.info(
+                            f"Recursively updating nested config: {key}")
+                        _recursive_update_config(target_value, value_new)
+                    elif (key in ["torch_dtype", "dtype"]
+                          and isinstance(value_new, str)
+                          and value_new != "auto"):
+                        # Special handling for the torch_dtype (deprecated) and dtype keys so
+                        # the config stores a torch.dtype object instead of a string.
+                        dtype = getattr(torch, value_new)
+                        assert isinstance(dtype,
+                                          torch.dtype), f"Invalid {dtype=}"
+                        setattr(config, key, dtype)
+                        logger.info(
+                            f"Applied model_kwargs: {key}={dtype} (previous value: {target_value})"
+                        )
+                    else:
+                        # Direct update for simple values
+                        setattr(config, key, value_new)
+                        logger.info(
+                            f"Applied model_kwargs: {key}={value_new} (previous value: {target_value})"
+                        )
+
+            _recursive_update_config(pretrained_config, model_kwargs)
+
         quant_config = QuantConfig()
         layer_quant_config = None
         moe_backend = kwargs.get('moe_backend', 'CUTLASS')
diff --git a/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py b/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py
index c679734fcf..dbc7abf730 100644
--- a/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py
+++ b/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py
@@ -8,6 +8,7 @@ from transformers import PretrainedConfig, WhisperConfig
 
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
 from tensorrt_llm._torch.models.modeling_utils import register_config_loader
+from tensorrt_llm.logger import logger
 from tensorrt_llm.models.modeling_utils import QuantConfig
 from tensorrt_llm.quantization.mode import QuantAlgo
@@ -327,6 +328,15 @@ class MistralConfigLoader(BaseConfigLoader):
             block_size = (128, 128)
             quant_config.group_size = block_size[0]
 
+        # model_kwargs is not supported for Mistral-format checkpoints.
+        # Pop it from kwargs to avoid passing it to ModelConfig.__init__ (which doesn't accept it).
+        model_kwargs = kwargs.pop("model_kwargs", None)
+        if model_kwargs:
+            logger.warning(
+                "model_kwargs is not supported for Mistral format checkpoints. "
+                f"Ignoring model_kwargs: {model_kwargs}"
+            )
+
         kwargs.pop("trust_remote_code", None)  # ModelConfig does not have this input parameter
         model_config = ModelConfig(
             pretrained_config=pretrained_config,
diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py
index 4756e24d08..1d24d24924 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_loader.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py
@@ -338,8 +338,8 @@ class ModelLoader:
             self, checkpoint_dir: str,
             checkpoint_loader: BaseCheckpointLoader) -> ModelConfig:
         """Loads and validates the model configuration."""
-        config = checkpoint_loader.load_config(
-            checkpoint_dir,
+        load_config_kwargs = dict(
+            checkpoint_dir=checkpoint_dir,
             trust_remote_code=True,
             mapping=self.mapping,
             enable_min_latency=self.llm_args.enable_min_latency,
@@ -363,6 +363,12 @@ class ModelLoader:
             nvfp4_gemm_allowed_backends=self.llm_args.nvfp4_gemm_config.
             allowed_backends)
 
+        # Only pass model_kwargs if it's explicitly set (not None).
+        if self.llm_args.model_kwargs is not None:
+            load_config_kwargs['model_kwargs'] = self.llm_args.model_kwargs
+
+        config = checkpoint_loader.load_config(**load_config_kwargs)
+
         # Store nvfp4 config in extra_attrs for Linear layer access
         config.extra_attrs[
             'nvfp4_gemm_allowed_backends'] = config.nvfp4_gemm_allowed_backends
@@ -373,9 +379,13 @@ class ModelLoader:
             config, self.llm_args.kv_cache_config.mamba_ssm_cache_dtype)
 
         # Allow overriding the number of layers via environment variable
+        # Note: this is kept for backward compatibility, but model_kwargs is preferred.
         num_layers_override = int(
             os.environ.get("TLLM_OVERRIDE_LAYER_NUM", "0"))
         if num_layers_override > 0:
+            logger.warning(
+                "TLLM_OVERRIDE_LAYER_NUM is deprecated. Use model_kwargs instead: "
+                f"model_kwargs={{'num_hidden_layers': {num_layers_override}}}")
             config.pretrained_config.num_hidden_layers = num_layers_override
             for sub_config in ["text_config", "vision_config"]:
                 if hasattr(config.pretrained_config, sub_config):
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 3f15252b84..1baae2a8eb 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -1907,6 +1907,13 @@ class BaseLlmArgs(StrictBaseModel):
 
     # Below are all remaining arguments
 
+    model_kwargs: Optional[Dict[str, Any]] = Field(
+        default=None,
+        description="Optional parameters overriding model config defaults. "
+        "Precedence: (1) model_kwargs, (2) model config file, (3) model config class defaults. "
+        "Unknown keys are ignored.",
+        status="prototype")
+
     pipeline_parallel_size: int = Field(
         default=1, description="The pipeline parallel size.")
 
diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml
index 4b6f8cedab..1c261c55fc 100644
--- a/tests/unittest/api_stability/references/llm.yaml
+++ b/tests/unittest/api_stability/references/llm.yaml
@@ -227,6 +227,10 @@ methods:
         annotation: Optional[Dict[str, str]]
         default: null
         status: prototype
+      model_kwargs:
+        annotation: Optional[Dict[str, Any]]
+        default: null
+        status: prototype
     return_annotation: None
   generate:
     parameters:
diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py
index 55c6c7b055..b5c74bc8e7 100644
--- a/tests/unittest/llmapi/test_llm_args.py
+++ b/tests/unittest/llmapi/test_llm_args.py
@@ -138,6 +138,19 @@ max_seq_len: 128
         assert llm_args.max_num_tokens == 256
         assert llm_args.max_seq_len == 128
 
+    @pytest.mark.parametrize("llm_args_cls", [TrtLlmArgs, TorchLlmArgs])
+    def test_llm_args_with_model_kwargs(self, llm_args_cls):
+        yaml_content = """
+model_kwargs:
+  num_hidden_layers: 2
+        """
+        dict_content = self._yaml_to_dict(yaml_content)
+        llm_args = llm_args_cls(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
+                                                        dict_content)
+        llm_args = llm_args_cls(**llm_args_dict)
+        assert llm_args.model_kwargs['num_hidden_layers'] == 2
+
 
 def check_defaults(py_config_cls, pybind_config_cls):
     py_config = py_config_cls()
@@ -445,6 +458,18 @@ class TestTorchLlmArgs:
         args = TorchLlmArgs(model=llama_model_path)
         args.invalid_arg = 1
 
+    @print_traceback_on_error
+    def test_model_kwargs_with_num_hidden_layers(self):
+        """Test that model_kwargs can override num_hidden_layers."""
+        from tensorrt_llm._torch.model_config import ModelConfig
+        config_no_kwargs = ModelConfig.from_pretrained(
+            llama_model_path).pretrained_config
+        model_kwargs = {'num_hidden_layers': 2}
+        config_with_kwargs = ModelConfig.from_pretrained(
+            llama_model_path, model_kwargs=model_kwargs).pretrained_config
+        assert config_no_kwargs.num_hidden_layers != config_with_kwargs.num_hidden_layers
+        assert config_with_kwargs.num_hidden_layers == 2
+
 
 class TestTrtLlmArgs:
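
Usage note: a minimal sketch of how the new model_kwargs argument could be exercised end to end. It assumes the tensorrt_llm.LLM entry point forwards the argument as the llm.yaml api-stability reference above indicates; the checkpoint path is a placeholder.

    # Minimal usage sketch (assumed entry point; placeholder model path).
    from tensorrt_llm import LLM

    llm = LLM(
        model="/path/to/hf-checkpoint",
        # model_kwargs takes precedence over the checkpoint's config file.
        # num_hidden_layers=2 loads a truncated model (useful for smoke tests) and
        # replaces the deprecated TLLM_OVERRIDE_LAYER_NUM environment variable;
        # dtype strings such as "bfloat16" are converted to torch.dtype objects.
        model_kwargs={"num_hidden_layers": 2, "torch_dtype": "bfloat16"},
    )

The same override can also be supplied via an extra-options YAML file, which is what the new test_llm_args_with_model_kwargs test exercises through update_llm_args_with_extra_dict.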