From 61f40c0159e8a703b0a87176a44d13ac0879462a Mon Sep 17 00:00:00 2001
From: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Date: Sun, 28 Dec 2025 01:03:24 -0800
Subject: [PATCH 1/7] Support model_kwargs for torch backend

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
---
 tensorrt_llm/_torch/model_config.py    | 53 ++++++++++++++++++-
 .../_torch/pyexecutor/model_loader.py  |  7 ++-
 tensorrt_llm/llmapi/llm_args.py        | 10 ++++
 3 files changed, 68 insertions(+), 2 deletions(-)

diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index 459034ec0e..877c0ff594 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -4,7 +4,7 @@ import os
 import tempfile
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Dict, Generic, List, Optional, TypeVar
+from typing import Any, Dict, Generic, List, Optional, TypeVar
 
 import filelock
 import torch
@@ -452,6 +452,57 @@ class ModelConfig(Generic[TConfig]):
         # Some checkpoints lack torch_dtype, populate with dtype
         pretrained_config.torch_dtype = getattr(pretrained_config, 'dtype',
                                                 None)
+
+        # Apply model_kwargs to override config parameters if provided
+        model_kwargs = kwargs.pop('model_kwargs', None)
+        if model_kwargs:
+            from tensorrt_llm.logger import logger
+
+            def _recursive_update_config(config: transformers.PretrainedConfig,
+                                         update_dict: Dict[str, Any]):
+                """
+                Recursively update a PretrainedConfig object with values from update_dict.
+                Args:
+                    config: PretrainedConfig object to update
+                    update_dict: Dictionary with values to update in the config
+                """
+                for key, value_new in update_dict.items():
+                    if not hasattr(config, key):
+                        logger.warning(
+                            f"model_kwargs key '{key}' not found in pretrained_config, ignoring."
+                        )
+                        continue
+
+                    target_value = getattr(config, key)
+
+                    # Handle nested PretrainedConfig objects when value is a dict
+                    if isinstance(value_new, dict) and isinstance(
+                            target_value, transformers.PretrainedConfig):
+                        # Recursively update the nested config
+                        logger.info(
+                            f"Recursively updating nested config: {key}")
+                        _recursive_update_config(target_value, value_new)
+                    elif (key in ["torch_dtype", "dtype"]
+                          and isinstance(value_new, str)
+                          and value_new != "auto"):
+                        # Special-case the torch_dtype (DEPRECATED!) and dtype keys so the
+                        # config stores a proper torch.dtype object instead of a string.
+                        dtype = getattr(torch, value_new)
+                        assert isinstance(dtype,
+                                          torch.dtype), f"Invalid {dtype=}"
+                        setattr(config, key, dtype)
+                        logger.info(
+                            f"Applied model_kwargs: {key}={dtype} (previous value: {target_value})"
+                        )
+                    else:
+                        # Direct update for simple values
+                        setattr(config, key, value_new)
+                        logger.info(
+                            f"Applied model_kwargs: {key}={value_new} (previous value: {target_value})"
+                        )
+
+            _recursive_update_config(pretrained_config, model_kwargs)
+
         quant_config = QuantConfig()
         layer_quant_config = None
         moe_backend = kwargs.get('moe_backend', 'CUTLASS')
diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py
index 4756e24d08..2726a6343c 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_loader.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py
@@ -361,7 +361,8 @@ class ModelLoader:
             use_low_precision_moe_combine=self.llm_args.moe_config.
             use_low_precision_moe_combine,
             nvfp4_gemm_allowed_backends=self.llm_args.nvfp4_gemm_config.
-            allowed_backends)
+            allowed_backends,
+            model_kwargs=self.llm_args.model_kwargs)
 
         # Store nvfp4 config in extra_attrs for Linear layer access
         config.extra_attrs[
@@ -373,9 +374,13 @@ class ModelLoader:
             config, self.llm_args.kv_cache_config.mamba_ssm_cache_dtype)
 
         # Allow overriding the number of layers via environment variable
+        # Note: This is kept for backward compatibility, but model_kwargs is preferred
         num_layers_override = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM",
                                                  "0"))
         if num_layers_override > 0:
+            logger.warning(
+                f"TLLM_OVERRIDE_LAYER_NUM is deprecated. Use model_kwargs instead: "
+                f"model_kwargs={{'num_hidden_layers': {num_layers_override}}}")
             config.pretrained_config.num_hidden_layers = num_layers_override
             for sub_config in ["text_config", "vision_config"]:
                 if hasattr(config.pretrained_config, sub_config):
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 3f15252b84..54998b42fb 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -1907,6 +1907,16 @@ class BaseLlmArgs(StrictBaseModel):
 
     # Below are all remaining arguments
 
+    model_kwargs: Dict[str, Any] = Field(
+        default_factory=dict,
+        description=
+        "Extra kwargs for the model config class to customize the model config. "
+        "These arguments take precedence over default values or config values in the model config "
+        "file. Arguments are resolved in order: 1) Default values in model config class, 2) Values "
+        "in model config file, 3) Values in model_kwargs. Note: if a kwarg doesn't exist in the "
+        "model config class, it will be ignored.",
+        status="beta")
+
     pipeline_parallel_size: int = Field(
         default=1, description="The pipeline parallel size.")
 
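
Usage sketch (illustrative, not part of the patch): with the BaseLlmArgs field above, overrides can be passed straight through the LLM API. The checkpoint paths are placeholders; the nested-dict and dtype-string behaviors follow _recursive_update_config in the model_config.py hunk.

    from tensorrt_llm import LLM

    # Override top-level pretrained-config fields at load time. Unknown keys
    # are ignored with a warning; dtype strings such as "bfloat16" are
    # resolved to torch.dtype objects via getattr(torch, ...).
    llm = LLM(
        model="/path/to/llama",  # placeholder checkpoint path
        model_kwargs={"num_hidden_layers": 2, "torch_dtype": "bfloat16"},
    )

    # Nested dicts update nested PretrainedConfig objects (e.g. a multimodal
    # checkpoint's text_config) via the recursive update.
    llm = LLM(
        model="/path/to/vlm",  # placeholder checkpoint path
        model_kwargs={"text_config": {"num_hidden_layers": 2}},
    )
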
" + "Unknown keys are ignored", status="beta") pipeline_parallel_size: int = Field( From 144a903896b7f2168836cdaa31d64272e00f13ba Mon Sep 17 00:00:00 2001 From: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com> Date: Fri, 2 Jan 2026 14:45:04 -0800 Subject: [PATCH 3/7] Added unittest Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com> --- tests/unittest/llmapi/test_llm_args.py | 35 ++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py index 55c6c7b055..8b91a6131c 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -138,6 +138,30 @@ max_seq_len: 128 assert llm_args.max_num_tokens == 256 assert llm_args.max_seq_len == 128 + def test_llm_args_with_model_kwargs_trt(self): + yaml_content = """ +model_kwargs: + num_hidden_layers: 2 + """ + dict_content = self._yaml_to_dict(yaml_content) + llm_args = TrtLlmArgs(model=llama_model_path) + llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(), + dict_content) + llm_args = TrtLlmArgs(**llm_args_dict) + assert llm_args.model_kwargs['num_hidden_layers'] == 2 + + def test_llm_args_with_model_kwargs_pt(self): + yaml_content = """ +model_kwargs: + num_hidden_layers: 2 + """ + dict_content = self._yaml_to_dict(yaml_content) + llm_args = TorchLlmArgs(model=llama_model_path) + llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(), + dict_content) + llm_args = TorchLlmArgs(**llm_args_dict) + assert llm_args.model_kwargs['num_hidden_layers'] == 2 + def check_defaults(py_config_cls, pybind_config_cls): py_config = py_config_cls() @@ -445,6 +469,17 @@ class TestTorchLlmArgs: args = TorchLlmArgs(model=llama_model_path) args.invalid_arg = 1 + @print_traceback_on_error + def test_model_kwargs_with_num_hidden_layers(self): + """Test that model_kwargs can override num_hidden_layers.""" + from tensorrt_llm._torch.model_config import ModelConfig + + model_kwargs = {'num_hidden_layers': 2} + + config = ModelConfig.from_pretrained(llama_model_path, + model_kwargs=model_kwargs) + assert config.pretrained_config.num_hidden_layers == 2 + class TestTrtLlmArgs: From 25524bed2f6ff24e359ab55d9d08db02ce1a817a Mon Sep 17 00:00:00 2001 From: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com> Date: Fri, 2 Jan 2026 20:51:52 -0800 Subject: [PATCH 4/7] Fix CI failure due to unregistered reference parameter for LLM API Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com> --- tests/unittest/api_stability/references/llm.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index 4b6f8cedab..f7062d5718 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -227,6 +227,10 @@ methods: annotation: Optional[Dict[str, str]] default: null status: prototype + model_kwargs: + annotation: object + default: null + status: beta return_annotation: None generate: parameters: From 1497113a051b891179a995914940da2322a67f23 Mon Sep 17 00:00:00 2001 From: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com> Date: Tue, 6 Jan 2026 00:40:08 -0800 Subject: [PATCH 5/7] Applied review comment Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com> --- tensorrt_llm/_torch/model_config.py | 1 - tensorrt_llm/llmapi/llm_args.py 
From 1497113a051b891179a995914940da2322a67f23 Mon Sep 17 00:00:00 2001
From: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Date: Tue, 6 Jan 2026 00:40:08 -0800
Subject: [PATCH 5/7] Applied review comment

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
---
 tensorrt_llm/_torch/model_config.py     |  1 -
 tensorrt_llm/llmapi/llm_args.py         |  2 +-
 .../api_stability/references/llm.yaml   |  4 ++--
 tests/unittest/llmapi/test_llm_args.py  | 30 ++++++++++--------------------
 4 files changed, 13 insertions(+), 24 deletions(-)

diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index 877c0ff594..1629f975df 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -456,7 +456,6 @@ class ModelConfig(Generic[TConfig]):
         # Apply model_kwargs to override config parameters if provided
         model_kwargs = kwargs.pop('model_kwargs', None)
         if model_kwargs:
-            from tensorrt_llm.logger import logger
 
             def _recursive_update_config(config: transformers.PretrainedConfig,
                                          update_dict: Dict[str, Any]):
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 780f1840de..fa34bdda0d 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -1912,7 +1912,7 @@ class BaseLlmArgs(StrictBaseModel):
         description="Optional parameters overriding model config defaults. "
         "Precedence: (1) model_kwargs, (2) model config file, (3) model config class defaults. "
         "Unknown keys are ignored.",
-        status="beta")
+        status="prototype")
 
     pipeline_parallel_size: int = Field(
         default=1, description="The pipeline parallel size.")
diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml
index f7062d5718..1c261c55fc 100644
--- a/tests/unittest/api_stability/references/llm.yaml
+++ b/tests/unittest/api_stability/references/llm.yaml
@@ -228,9 +228,9 @@ methods:
       default: null
       status: prototype
     model_kwargs:
-      annotation: object
+      annotation: Optional[Dict[str, Any]]
       default: null
-      status: beta
+      status: prototype
     return_annotation: None
   generate:
     parameters:
diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py
index 8b91a6131c..b5c74bc8e7 100644
--- a/tests/unittest/llmapi/test_llm_args.py
+++ b/tests/unittest/llmapi/test_llm_args.py
@@ -138,28 +138,17 @@ max_seq_len: 128
         assert llm_args.max_num_tokens == 256
         assert llm_args.max_seq_len == 128
 
-    def test_llm_args_with_model_kwargs_trt(self):
+    @pytest.mark.parametrize("llm_args_cls", [TrtLlmArgs, TorchLlmArgs])
+    def test_llm_args_with_model_kwargs(self, llm_args_cls):
         yaml_content = """
 model_kwargs:
   num_hidden_layers: 2
     """
         dict_content = self._yaml_to_dict(yaml_content)
-        llm_args = TrtLlmArgs(model=llama_model_path)
+        llm_args = llm_args_cls(model=llama_model_path)
         llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
                                                         dict_content)
-        llm_args = TrtLlmArgs(**llm_args_dict)
-        assert llm_args.model_kwargs['num_hidden_layers'] == 2
-
-    def test_llm_args_with_model_kwargs_pt(self):
-        yaml_content = """
-model_kwargs:
-  num_hidden_layers: 2
-    """
-        dict_content = self._yaml_to_dict(yaml_content)
-        llm_args = TorchLlmArgs(model=llama_model_path)
-        llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
-                                                        dict_content)
-        llm_args = TorchLlmArgs(**llm_args_dict)
+        llm_args = llm_args_cls(**llm_args_dict)
         assert llm_args.model_kwargs['num_hidden_layers'] == 2
 
 
@@ -473,12 +462,13 @@ class TestTorchLlmArgs:
     def test_model_kwargs_with_num_hidden_layers(self):
         """Test that model_kwargs can override num_hidden_layers."""
         from tensorrt_llm._torch.model_config import ModelConfig
-
+        config_no_kwargs = ModelConfig.from_pretrained(
+            llama_model_path).pretrained_config
         model_kwargs = {'num_hidden_layers': 2}
-
-        config = ModelConfig.from_pretrained(llama_model_path,
-                                             model_kwargs=model_kwargs)
-        assert config.pretrained_config.num_hidden_layers == 2
+        config_with_kwargs = ModelConfig.from_pretrained(
+            llama_model_path, model_kwargs=model_kwargs).pretrained_config
+        assert config_no_kwargs.num_hidden_layers != config_with_kwargs.num_hidden_layers
+        assert config_with_kwargs.num_hidden_layers == 2
 
 
 class TestTrtLlmArgs:
From 4439e84db174e08b02fdf9b041ad683327f24401 Mon Sep 17 00:00:00 2001
From: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Date: Wed, 7 Jan 2026 22:08:47 -0800
Subject: [PATCH 6/7] Fix unittest failure

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
---
 tensorrt_llm/_torch/auto_deploy/models/hf.py | 2 +-
 tensorrt_llm/llmapi/llm_args.py              | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py
index f20c52babf..bf5384af15 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/hf.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py
@@ -121,7 +121,7 @@ class AutoModelForCausalLMFactory(AutoModelFactory):
         self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs)
         self.model_kwargs = deep_merge_dicts(
             self._model_defaults,
-            self.model_kwargs,
+            self.model_kwargs or {},
         )
 
         # set sharding config source to huggingface
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index fa34bdda0d..1baae2a8eb 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -1907,8 +1907,8 @@ class BaseLlmArgs(StrictBaseModel):
 
     # Below are all remaining arguments
 
-    model_kwargs: Dict[str, Any] = Field(
-        default_factory=dict,
+    model_kwargs: Optional[Dict[str, Any]] = Field(
+        default=None,
         description="Optional parameters overriding model config defaults. "
         "Precedence: (1) model_kwargs, (2) model config file, (3) model config class defaults. "
         "Unknown keys are ignored.",
         status="prototype")
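
Usage sketch (illustrative, not part of the patch): with default=None, merge sites must normalize None before merging, which is exactly what the hf.py hunk above does with "self.model_kwargs or {}". merge_model_kwargs below is a hypothetical shallow stand-in for deep_merge_dicts.

    from typing import Any, Dict, Optional

    def merge_model_kwargs(defaults: Dict[str, Any],
                           model_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        # model_kwargs may now be None (the field default) rather than {};
        # normalize it so dict-merge helpers never receive None.
        merged = dict(defaults)
        merged.update(model_kwargs or {})
        return merged

    assert merge_model_kwargs({"num_hidden_layers": 32}, None)["num_hidden_layers"] == 32
    assert merge_model_kwargs({"num_hidden_layers": 32},
                              {"num_hidden_layers": 2})["num_hidden_layers"] == 2
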
From be0949bcabda6dda84b164b146e769b28c525607 Mon Sep 17 00:00:00 2001
From: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Date: Sat, 10 Jan 2026 01:39:23 -0800
Subject: [PATCH 7/7] Fix unittest failure

- model_kwargs is an optional parameter, and it is not always set.
- This is handled by the standard ConfigLoader, but not by custom
  ConfigLoaders such as MistralConfigLoader.
- With this update, model_kwargs is supported only by the standard
  ConfigLoader, and the custom loader warns about and ignores it.

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
---
 .../models/checkpoints/mistral/config_loader.py | 10 ++++++++++
 tensorrt_llm/_torch/pyexecutor/model_loader.py  | 13 +++++++++----
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py b/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py
index c679734fcf..dbc7abf730 100644
--- a/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py
+++ b/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py
@@ -8,6 +8,7 @@ from transformers import PretrainedConfig, WhisperConfig
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
 from tensorrt_llm._torch.models.modeling_utils import register_config_loader
+from tensorrt_llm.logger import logger
 from tensorrt_llm.models.modeling_utils import QuantConfig
 from tensorrt_llm.quantization.mode import QuantAlgo
 
@@ -327,6 +328,15 @@ class MistralConfigLoader(BaseConfigLoader):
             block_size = (128, 128)
             quant_config.group_size = block_size[0]
 
+        # model_kwargs is not supported for Mistral format checkpoints.
+        # Extract it from kwargs to avoid passing it to ModelConfig.__init__ (which doesn't accept it).
+        model_kwargs = kwargs.pop("model_kwargs", None)
+        if model_kwargs:
+            logger.warning(
+                "model_kwargs is not supported for Mistral format checkpoints. "
+                f"Ignoring model_kwargs: {model_kwargs}"
+            )
+
         kwargs.pop("trust_remote_code", None)  # ModelConfig does not have this input parameter
         model_config = ModelConfig(
             pretrained_config=pretrained_config,
diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py
index 2726a6343c..1d24d24924 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_loader.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py
@@ -338,8 +338,8 @@ class ModelLoader:
             self, checkpoint_dir: str,
             checkpoint_loader: BaseCheckpointLoader) -> ModelConfig:
         """Loads and validates the model configuration."""
-        config = checkpoint_loader.load_config(
-            checkpoint_dir,
+        load_config_kwargs = dict(
+            checkpoint_dir=checkpoint_dir,
             trust_remote_code=True,
             mapping=self.mapping,
             enable_min_latency=self.llm_args.enable_min_latency,
@@ -361,8 +361,13 @@ class ModelLoader:
             use_low_precision_moe_combine=self.llm_args.moe_config.
             use_low_precision_moe_combine,
             nvfp4_gemm_allowed_backends=self.llm_args.nvfp4_gemm_config.
-            allowed_backends,
-            model_kwargs=self.llm_args.model_kwargs)
+            allowed_backends)
+
+        # Only pass model_kwargs if it's explicitly set (not None)
+        if self.llm_args.model_kwargs is not None:
+            load_config_kwargs['model_kwargs'] = self.llm_args.model_kwargs
+
+        config = checkpoint_loader.load_config(**load_config_kwargs)
 
         # Store nvfp4 config in extra_attrs for Linear layer access
         config.extra_attrs[
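
Usage sketch (illustrative, not part of the patch): other custom ConfigLoader implementations that do not honor model_kwargs can apply the same guard as the Mistral loader above; class scaffolding and the rest of the loader body are elided.

    from tensorrt_llm.logger import logger

    def load_config(self, checkpoint_dir, **kwargs):
        # Pop model_kwargs so it never reaches ModelConfig.__init__, which
        # does not accept it; warn instead of failing, mirroring MistralConfigLoader.
        model_kwargs = kwargs.pop("model_kwargs", None)
        if model_kwargs:
            logger.warning(
                "model_kwargs is not supported by this checkpoint format. "
                f"Ignoring model_kwargs: {model_kwargs}")
        ...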