commit c1f868f9ea by Taylor Yeonbok Lee, 2026-01-13 21:25:09 +08:00, committed via GitHub
7 changed files with 110 additions and 4 deletions

View File

@@ -121,7 +121,7 @@ class AutoModelForCausalLMFactory(AutoModelFactory):
self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs)
self.model_kwargs = deep_merge_dicts(
self._model_defaults,
-self.model_kwargs,
+self.model_kwargs or {},
)
# set sharding config source to huggingface
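
For context on the `or {}` guard above: `model_kwargs` can now arrive as None (it defaults to None on the LLM args), and passing None straight into the dict merge would fail. A minimal sketch of the failure mode, using an illustrative stand-in for `deep_merge_dicts` (the real helper lives elsewhere in the repo and may differ):

def deep_merge_dicts(base: dict, override: dict) -> dict:
    """Illustrative stand-in: recursively merge `override` into a copy of `base`."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge_dicts(merged[key], value)
        else:
            merged[key] = value
    return merged


defaults = {"use_cache": False}
deep_merge_dicts(defaults, {"num_hidden_layers": 2})  # ok
deep_merge_dicts(defaults, None or {})                # ok: None is normalized to {}
# deep_merge_dicts(defaults, None)                    # AttributeError: 'NoneType' object has no attribute 'items'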

View File

@@ -4,7 +4,7 @@ import os
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
-from typing import Dict, Generic, List, Optional, TypeVar
+from typing import Any, Dict, Generic, List, Optional, TypeVar
import filelock
import torch
@@ -470,6 +470,56 @@ class ModelConfig(Generic[TConfig]):
        # Some checkpoints lack torch_dtype, populate with dtype
        pretrained_config.torch_dtype = getattr(pretrained_config, 'dtype',
                                                None)

        # Apply model_kwargs to override config parameters if provided
        model_kwargs = kwargs.pop('model_kwargs', None)
        if model_kwargs:

            def _recursive_update_config(config: transformers.PretrainedConfig,
                                         update_dict: Dict[str, Any]):
                """
                Recursively update a PretrainedConfig object with values from update_dict.

                Args:
                    config: PretrainedConfig object to update
                    update_dict: Dictionary with values to update in the config
                """
                for key, value_new in update_dict.items():
                    if not hasattr(config, key):
                        logger.warning(
                            f"model_kwargs key '{key}' not found in pretrained_config, ignoring."
                        )
                        continue

                    target_value = getattr(config, key)

                    # Handle nested PretrainedConfig objects when value is a dict
                    if isinstance(value_new, dict) and isinstance(
                            target_value, transformers.PretrainedConfig):
                        # Recursively update the nested config
                        logger.info(
                            f"Recursively updating nested config: {key}")
                        _recursive_update_config(target_value, value_new)
                    elif (key in ["torch_dtype", "dtype"]
                          and isinstance(value_new, str)
                          and value_new != "auto"):
                        # Special-case the torch_dtype (deprecated) and dtype keys so the
                        # config stores a torch.dtype object rather than a string.
                        dtype = getattr(torch, value_new)
                        assert isinstance(dtype,
                                          torch.dtype), f"Invalid {dtype=}"
                        setattr(config, key, dtype)
                        logger.info(
                            f"Applied model_kwargs: {key}={dtype} (previous value: {target_value})"
                        )
                    else:
                        # Direct update for simple values
                        setattr(config, key, value_new)
                        logger.info(
                            f"Applied model_kwargs: {key}={value_new} (previous value: {target_value})"
                        )

            _recursive_update_config(pretrained_config, model_kwargs)

        quant_config = QuantConfig()
        layer_quant_config = None
        moe_backend = kwargs.get('moe_backend', 'CUTLASS')
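
A usage sketch of the override path added above. The checkpoint paths are placeholders, and the nested `text_config` example assumes a checkpoint whose pretrained config actually carries a nested text config:

from tensorrt_llm._torch.model_config import ModelConfig

# Flat override: replaces pretrained_config.num_hidden_layers.
config = ModelConfig.from_pretrained(
    "/path/to/checkpoint",  # placeholder
    model_kwargs={"num_hidden_layers": 2}).pretrained_config
assert config.num_hidden_layers == 2

# String dtypes (other than "auto") are converted to torch.dtype objects, and
# dict values recurse into nested PretrainedConfig objects such as text_config.
config = ModelConfig.from_pretrained(
    "/path/to/multimodal-checkpoint",  # placeholder
    model_kwargs={
        "dtype": "bfloat16",
        "text_config": {
            "num_hidden_layers": 4
        },
    }).pretrained_config

# Keys the pretrained config does not define are logged as warnings and skipped.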

View File

@@ -8,6 +8,7 @@ from transformers import PretrainedConfig, WhisperConfig
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
from tensorrt_llm._torch.models.modeling_utils import register_config_loader
+from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo
@@ -327,6 +328,15 @@ class MistralConfigLoader(BaseConfigLoader):
block_size = (128, 128)
quant_config.group_size = block_size[0]
# model_kwargs is not supported for Mistral format checkpoints
# Extract it from kwargs to avoid passing to ModelConfig.__init__ (which doesn't accept it)
model_kwargs = kwargs.pop("model_kwargs", None)
if model_kwargs:
logger.warning(
"model_kwargs is not supported for Mistral format checkpoints. "
f"Ignoring model_kwargs: {model_kwargs}"
)
kwargs.pop("trust_remote_code", None) # ModelConfig does not have this input parameter
model_config = ModelConfig(
pretrained_config=pretrained_config,

View File

@@ -338,8 +338,8 @@ class ModelLoader:
            self, checkpoint_dir: str,
            checkpoint_loader: BaseCheckpointLoader) -> ModelConfig:
        """Loads and validates the model configuration."""
-       config = checkpoint_loader.load_config(
-           checkpoint_dir,
+       load_config_kwargs = dict(
+           checkpoint_dir=checkpoint_dir,
            trust_remote_code=True,
            mapping=self.mapping,
            enable_min_latency=self.llm_args.enable_min_latency,
@@ -363,6 +363,12 @@ class ModelLoader:
            nvfp4_gemm_allowed_backends=self.llm_args.nvfp4_gemm_config.
            allowed_backends)

        # Only pass model_kwargs if it's explicitly set (not None)
        if self.llm_args.model_kwargs is not None:
            load_config_kwargs['model_kwargs'] = self.llm_args.model_kwargs

        config = checkpoint_loader.load_config(**load_config_kwargs)

        # Store nvfp4 config in extra_attrs for Linear layer access
        config.extra_attrs[
            'nvfp4_gemm_allowed_backends'] = config.nvfp4_gemm_allowed_backends
@@ -373,9 +379,13 @@ class ModelLoader:
            config, self.llm_args.kv_cache_config.mamba_ssm_cache_dtype)

        # Allow overriding the number of layers via environment variable
        # Note: This is kept for backward compatibility, but model_kwargs is preferred
        num_layers_override = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM",
                                                 "0"))
        if num_layers_override > 0:
            logger.warning(
                f"TLLM_OVERRIDE_LAYER_NUM is deprecated. Use model_kwargs instead: "
                f"model_kwargs={{'num_hidden_layers': {num_layers_override}}}")
            config.pretrained_config.num_hidden_layers = num_layers_override
            for sub_config in ["text_config", "vision_config"]:
                if hasattr(config.pretrained_config, sub_config):
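
The deprecation warning above points users at the new path. A short sketch of the replacement, assuming the top-level `LLM` entry point (the checkpoint path is a placeholder):

from tensorrt_llm import LLM

# Deprecated, still honored for backward compatibility:
#   TLLM_OVERRIDE_LAYER_NUM=2 python run_model.py

# Preferred: the override travels through BaseLlmArgs.model_kwargs into
# checkpoint_loader.load_config() and onto config.pretrained_config.
llm = LLM(
    model="/path/to/checkpoint",  # placeholder
    model_kwargs={"num_hidden_layers": 2})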

View File

@@ -1907,6 +1907,13 @@ class BaseLlmArgs(StrictBaseModel):
    # Below are all remaining arguments
    model_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Optional parameters overriding model config defaults. "
        "Precedence: (1) model_kwargs, (2) model config file, (3) model config class defaults. "
        "Unknown keys are ignored.",
        status="prototype")

    pipeline_parallel_size: int = Field(
        default=1, description="The pipeline parallel size.")
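
A sketch of how the precedence reads in practice, mirroring the YAML-driven test added further down; the import path and checkpoint path are assumptions here:

from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

args = TorchLlmArgs(
    model="/path/to/checkpoint",  # placeholder
    model_kwargs={
        "num_hidden_layers": 2,  # (1) overrides the value in the checkpoint's config file
        "not_a_real_key": 1,  # unknown keys are warned about and ignored at load time
    })
assert args.model_kwargs["num_hidden_layers"] == 2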

View File

@@ -227,6 +227,10 @@ methods:
        annotation: Optional[Dict[str, str]]
        default: null
        status: prototype
      model_kwargs:
        annotation: Optional[Dict[str, Any]]
        default: null
        status: prototype
    return_annotation: None
  generate:
    parameters:

View File

@@ -138,6 +138,19 @@ max_seq_len: 128
        assert llm_args.max_num_tokens == 256
        assert llm_args.max_seq_len == 128

    @pytest.mark.parametrize("llm_args_cls", [TrtLlmArgs, TorchLlmArgs])
    def test_llm_args_with_model_kwargs(self, llm_args_cls):
        yaml_content = """
model_kwargs:
  num_hidden_layers: 2
"""
        dict_content = self._yaml_to_dict(yaml_content)

        llm_args = llm_args_cls(model=llama_model_path)
        llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
                                                        dict_content)
        llm_args = llm_args_cls(**llm_args_dict)
        assert llm_args.model_kwargs['num_hidden_layers'] == 2


def check_defaults(py_config_cls, pybind_config_cls):
    py_config = py_config_cls()
@@ -445,6 +458,18 @@ class TestTorchLlmArgs:
args = TorchLlmArgs(model=llama_model_path)
args.invalid_arg = 1

    @print_traceback_on_error
    def test_model_kwargs_with_num_hidden_layers(self):
        """Test that model_kwargs can override num_hidden_layers."""
        from tensorrt_llm._torch.model_config import ModelConfig

        config_no_kwargs = ModelConfig.from_pretrained(
            llama_model_path).pretrained_config

        model_kwargs = {'num_hidden_layers': 2}
        config_with_kwargs = ModelConfig.from_pretrained(
            llama_model_path, model_kwargs=model_kwargs).pretrained_config

        assert config_no_kwargs.num_hidden_layers != config_with_kwargs.num_hidden_layers
        assert config_with_kwargs.num_hidden_layers == 2
class TestTrtLlmArgs: