Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
Merge be0949bcab into 6df2c8a074
This commit is contained in: commit c1f868f9ea
@@ -121,7 +121,7 @@ class AutoModelForCausalLMFactory(AutoModelFactory):
         self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs)
         self.model_kwargs = deep_merge_dicts(
             self._model_defaults,
-            self.model_kwargs,
+            self.model_kwargs or {},
         )

         # set sharding config source to huggingface
@@ -4,7 +4,7 @@ import os
 import tempfile
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Dict, Generic, List, Optional, TypeVar
+from typing import Any, Dict, Generic, List, Optional, TypeVar

 import filelock
 import torch
@@ -470,6 +470,56 @@ class ModelConfig(Generic[TConfig]):
         # Some checkpoints lack torch_dtype, populate with dtype
         pretrained_config.torch_dtype = getattr(pretrained_config, 'dtype',
                                                 None)
+
+        # Apply model_kwargs to override config parameters if provided
+        model_kwargs = kwargs.pop('model_kwargs', None)
+        if model_kwargs:
+
+            def _recursive_update_config(config: transformers.PretrainedConfig,
+                                          update_dict: Dict[str, Any]):
+                """
+                Recursively update a PretrainedConfig object with values from update_dict.
+                Args:
+                    config: PretrainedConfig object to update
+                    update_dict: Dictionary with values to update in the config
+                """
+                for key, value_new in update_dict.items():
+                    if not hasattr(config, key):
+                        logger.warning(
+                            f"model_kwargs key '{key}' not found in pretrained_config, ignoring."
+                        )
+                        continue
+
+                    target_value = getattr(config, key)
+
+                    # Handle nested PretrainedConfig objects when value is a dict
+                    if isinstance(value_new, dict) and isinstance(
+                            target_value, transformers.PretrainedConfig):
+                        # Recursively update the nested config
+                        logger.info(
+                            f"Recursively updating nested config: {key}")
+                        _recursive_update_config(target_value, value_new)
+                    elif (key in ["torch_dtype", "dtype"]
+                          and isinstance(value_new, str)
+                          and value_new != "auto"):
+                        # Special handling of torch_dtype (DEPRECATED!) and dtype keys to ensure we
+                        # use the correct torch.dtype object instead of a string.
+                        dtype = getattr(torch, value_new)
+                        assert isinstance(dtype,
+                                          torch.dtype), f"Invalid {dtype=}"
+                        setattr(config, key, dtype)
+                        logger.info(
+                            f"Applied model_kwargs: {key}={dtype} (previous value: {target_value})"
+                        )
+                    else:
+                        # Direct update for simple values
+                        setattr(config, key, value_new)
+                        logger.info(
+                            f"Applied model_kwargs: {key}={value_new} (previous value: {target_value})"
+                        )
+
+            _recursive_update_config(pretrained_config, model_kwargs)
+
         quant_config = QuantConfig()
         layer_quant_config = None
         moe_backend = kwargs.get('moe_backend', 'CUTLASS')
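For orientation, here is a minimal standalone sketch of the override behavior this block implements, applied to a locally constructed Hugging Face config. The override values are illustrative and not part of this commit, and the nested-PretrainedConfig recursion is omitted for brevity:

import torch
import transformers

# Illustrative model_kwargs-style overrides: a plain integer field plus a
# dtype given as a string, mirroring the special-cased "torch_dtype" key.
overrides = {"num_hidden_layers": 2, "torch_dtype": "bfloat16"}

config = transformers.LlamaConfig()  # local config object, no checkpoint download

for key, value_new in overrides.items():
    if key in ("torch_dtype", "dtype") and isinstance(value_new, str) and value_new != "auto":
        value_new = getattr(torch, value_new)  # "bfloat16" -> torch.bfloat16
    setattr(config, key, value_new)

assert config.num_hidden_layers == 2
assert config.torch_dtype == torch.bfloat16

In the committed code, a dict value whose target attribute is itself a transformers.PretrainedConfig (for example the text_config of a multimodal model) is updated recursively instead of being overwritten.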
@@ -8,6 +8,7 @@ from transformers import PretrainedConfig, WhisperConfig
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
 from tensorrt_llm._torch.models.modeling_utils import register_config_loader
+from tensorrt_llm.logger import logger
 from tensorrt_llm.models.modeling_utils import QuantConfig
 from tensorrt_llm.quantization.mode import QuantAlgo

@@ -327,6 +328,15 @@ class MistralConfigLoader(BaseConfigLoader):
             block_size = (128, 128)
             quant_config.group_size = block_size[0]

+        # model_kwargs is not supported for Mistral format checkpoints.
+        # Extract it from kwargs to avoid passing it to ModelConfig.__init__ (which doesn't accept it).
+        model_kwargs = kwargs.pop("model_kwargs", None)
+        if model_kwargs:
+            logger.warning(
+                "model_kwargs is not supported for Mistral format checkpoints. "
+                f"Ignoring model_kwargs: {model_kwargs}"
+            )
+
         kwargs.pop("trust_remote_code", None)  # ModelConfig does not have this input parameter
         model_config = ModelConfig(
             pretrained_config=pretrained_config,
@@ -338,8 +338,8 @@ class ModelLoader:
             self, checkpoint_dir: str,
             checkpoint_loader: BaseCheckpointLoader) -> ModelConfig:
         """Loads and validates the model configuration."""
-        config = checkpoint_loader.load_config(
-            checkpoint_dir,
+        load_config_kwargs = dict(
+            checkpoint_dir=checkpoint_dir,
             trust_remote_code=True,
             mapping=self.mapping,
             enable_min_latency=self.llm_args.enable_min_latency,
@@ -363,6 +363,12 @@ class ModelLoader:
             nvfp4_gemm_allowed_backends=self.llm_args.nvfp4_gemm_config.
             allowed_backends)

+        # Only pass model_kwargs if it's explicitly set (not None)
+        if self.llm_args.model_kwargs is not None:
+            load_config_kwargs['model_kwargs'] = self.llm_args.model_kwargs
+
+        config = checkpoint_loader.load_config(**load_config_kwargs)
+
         # Store nvfp4 config in extra_attrs for Linear layer access
         config.extra_attrs[
             'nvfp4_gemm_allowed_backends'] = config.nvfp4_gemm_allowed_backends
@@ -373,9 +379,13 @@ class ModelLoader:
             config, self.llm_args.kv_cache_config.mamba_ssm_cache_dtype)

         # Allow overriding the number of layers via environment variable
+        # Note: This is kept for backward compatibility, but model_kwargs is preferred
         num_layers_override = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM",
                                                  "0"))
         if num_layers_override > 0:
+            logger.warning(
+                f"TLLM_OVERRIDE_LAYER_NUM is deprecated. Use model_kwargs instead: "
+                f"model_kwargs={{'num_hidden_layers': {num_layers_override}}}")
             config.pretrained_config.num_hidden_layers = num_layers_override
             for sub_config in ["text_config", "vision_config"]:
                 if hasattr(config.pretrained_config, sub_config):
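As a quick migration sketch for the deprecated environment variable (the checkpoint path and script name are placeholders; this assumes the LLM entry point forwards model_kwargs to the llm args like its other keyword fields):

# Before (deprecated):
#   TLLM_OVERRIDE_LAYER_NUM=2 python run_inference.py
#
# After: express the same override through model_kwargs.
from tensorrt_llm import LLM

llm = LLM(
    model="/path/to/checkpoint",            # placeholder path
    model_kwargs={"num_hidden_layers": 2},  # replaces TLLM_OVERRIDE_LAYER_NUM=2
)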
@@ -1907,6 +1907,13 @@ class BaseLlmArgs(StrictBaseModel):

     # Below are all remaining arguments

+    model_kwargs: Optional[Dict[str, Any]] = Field(
+        default=None,
+        description="Optional parameters overriding model config defaults. "
+        "Precedence: (1) model_kwargs, (2) model config file, (3) model config class defaults. "
+        "Unknown keys are ignored",
+        status="prototype")
+
     pipeline_parallel_size: int = Field(
         default=1, description="The pipeline parallel size.")

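A hedged sketch of the new prototype field set directly on the args object, including a dtype string and a nested sub-config override (the checkpoint path and the text_config key are assumptions about the target model, not taken from this commit):

from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

args = TorchLlmArgs(
    model="/path/to/checkpoint",  # placeholder path
    model_kwargs={
        "num_hidden_layers": 2,                   # overrides the checkpoint value
        "torch_dtype": "bfloat16",                # converted to torch.bfloat16 at load time
        "text_config": {"num_hidden_layers": 2},  # recursed into only if the model has a text_config
    },
)
assert args.model_kwargs["torch_dtype"] == "bfloat16"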
@@ -227,6 +227,10 @@ methods:
         annotation: Optional[Dict[str, str]]
         default: null
         status: prototype
+      model_kwargs:
+        annotation: Optional[Dict[str, Any]]
+        default: null
+        status: prototype
     return_annotation: None
   generate:
     parameters:
@@ -138,6 +138,19 @@ max_seq_len: 128
         assert llm_args.max_num_tokens == 256
         assert llm_args.max_seq_len == 128

+    @pytest.mark.parametrize("llm_args_cls", [TrtLlmArgs, TorchLlmArgs])
+    def test_llm_args_with_model_kwargs(self, llm_args_cls):
+        yaml_content = """
+model_kwargs:
+  num_hidden_layers: 2
+"""
+        dict_content = self._yaml_to_dict(yaml_content)
+        llm_args = llm_args_cls(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
+                                                        dict_content)
+        llm_args = llm_args_cls(**llm_args_dict)
+        assert llm_args.model_kwargs['num_hidden_layers'] == 2
+

 def check_defaults(py_config_cls, pybind_config_cls):
     py_config = py_config_cls()
@@ -445,6 +458,18 @@ class TestTorchLlmArgs:
         args = TorchLlmArgs(model=llama_model_path)
         args.invalid_arg = 1

+    @print_traceback_on_error
+    def test_model_kwargs_with_num_hidden_layers(self):
+        """Test that model_kwargs can override num_hidden_layers."""
+        from tensorrt_llm._torch.model_config import ModelConfig
+        config_no_kwargs = ModelConfig.from_pretrained(
+            llama_model_path).pretrained_config
+        model_kwargs = {'num_hidden_layers': 2}
+        config_with_kwargs = ModelConfig.from_pretrained(
+            llama_model_path, model_kwargs=model_kwargs).pretrained_config
+        assert config_no_kwargs.num_hidden_layers != config_with_kwargs.num_hidden_layers
+        assert config_with_kwargs.num_hidden_layers == 2
+

 class TestTrtLlmArgs: