[TRTLLM-8684][chore] Migrate BuildConfig to Pydantic, add a Python wrapper for KVCacheType enum (#8330)

Signed-off-by: Anish Shanbhag <ashanbhag@nvidia.com>
Author: Anish Shanbhag
Date: 2025-10-28 09:17:26 -07:00 (committed by GitHub)
Parent: cdc9e5e645
Commit: a09b38a862
32 changed files with 363 additions and 429 deletions

View File

@ -23,7 +23,7 @@ from transformers import AutoModelForCausalLM, LlamaTokenizer
import tensorrt_llm import tensorrt_llm
import tensorrt_llm.profiler as profiler import tensorrt_llm.profiler as profiler
from tensorrt_llm.bindings import KVCacheType from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.logger import logger from tensorrt_llm.logger import logger
from tensorrt_llm.quantization import QuantMode from tensorrt_llm.quantization import QuantMode
@ -97,7 +97,7 @@ def TRTLLaMA(args, config):
quantization_config = pretrained_config['quantization'] quantization_config = pretrained_config['quantization']
build_config = config['build_config'] build_config = config['build_config']
kv_cache_type = KVCacheType.from_string(build_config['kv_cache_type']) kv_cache_type = KVCacheType(build_config['kv_cache_type'])
plugin_config = build_config['plugin_config'] plugin_config = build_config['plugin_config']
dtype = pretrained_config['dtype'] dtype = pretrained_config['dtype']
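
The call sites above switch from the pybind11 enum's KVCacheType.from_string() to plain construction, which works because the new wrapper is a str-backed Python enum. A minimal sketch of the pattern (the config value is illustrative):

from tensorrt_llm.llmapi.kv_cache_type import KVCacheType

build_config = {"kv_cache_type": "paged"}  # illustrative value

# Old: kv_cache_type = KVCacheType.from_string(build_config["kv_cache_type"])
# New: the str-backed enum accepts the value directly
kv_cache_type = KVCacheType(build_config["kv_cache_type"])
assert kv_cache_type is KVCacheType.PAGED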

View File

@ -27,7 +27,7 @@ from utils import add_common_args
import tensorrt_llm import tensorrt_llm
import tensorrt_llm.profiler as profiler import tensorrt_llm.profiler as profiler
from tensorrt_llm import logger from tensorrt_llm import logger
from tensorrt_llm.bindings import KVCacheType from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.quantization import QuantMode from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime import (PYTHON_BINDINGS, ModelConfig, ModelRunner, from tensorrt_llm.runtime import (PYTHON_BINDINGS, ModelConfig, ModelRunner,
SamplingConfig, Session, TensorInfo) SamplingConfig, Session, TensorInfo)
@ -122,8 +122,7 @@ class QWenInfer(object):
num_kv_heads = config["pretrained_config"].get("num_key_value_heads", num_kv_heads = config["pretrained_config"].get("num_key_value_heads",
num_heads) num_heads)
if "kv_cache_type" in config["build_config"]: if "kv_cache_type" in config["build_config"]:
kv_cache_type = KVCacheType.from_string( kv_cache_type = KVCacheType(config["build_config"]["kv_cache_type"])
config["build_config"]["kv_cache_type"])
else: else:
kv_cache_type = KVCacheType.CONTINUOUS kv_cache_type = KVCacheType.CONTINUOUS

View File

@ -25,7 +25,7 @@ from vit_onnx_trt import Preprocss
import tensorrt_llm import tensorrt_llm
import tensorrt_llm.profiler as profiler import tensorrt_llm.profiler as profiler
from tensorrt_llm import logger from tensorrt_llm import logger
from tensorrt_llm.bindings import KVCacheType from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.quantization import QuantMode from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime import (ModelConfig, SamplingConfig, Session, from tensorrt_llm.runtime import (ModelConfig, SamplingConfig, Session,
TensorInfo) TensorInfo)
@ -118,8 +118,7 @@ class QWenInfer(object):
num_kv_heads = config["pretrained_config"].get("num_key_value_heads", num_kv_heads = config["pretrained_config"].get("num_key_value_heads",
num_heads) num_heads)
if "kv_cache_type" in config["build_config"]: if "kv_cache_type" in config["build_config"]:
kv_cache_type = KVCacheType.from_string( kv_cache_type = KVCacheType(config["build_config"]["kv_cache_type"])
config["build_config"]["kv_cache_type"])
else: else:
kv_cache_type = KVCacheType.CONTINUOUS kv_cache_type = KVCacheType.CONTINUOUS

View File

@ -33,7 +33,8 @@ import tensorrt_llm
import tensorrt_llm.logger as logger import tensorrt_llm.logger as logger
from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt, from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
trt_dtype_to_torch) trt_dtype_to_torch)
from tensorrt_llm.bindings import GptJsonConfig, KVCacheType from tensorrt_llm.bindings import GptJsonConfig
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelConfig, SamplingConfig from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelConfig, SamplingConfig
from tensorrt_llm.runtime.session import Session, TensorInfo from tensorrt_llm.runtime.session import Session, TensorInfo

View File

@ -9,7 +9,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
from tensorrt_llm.models.modeling_utils import QuantConfig from tensorrt_llm.models.modeling_utils import QuantConfig
from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, _ParallelConfig from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, _ParallelConfig
from ...llmapi.utils import get_type_repr
from .models import ModelFactory, ModelFactoryRegistry from .models import ModelFactory, ModelFactoryRegistry
from .utils._config import DynamicYamlMixInForSettings from .utils._config import DynamicYamlMixInForSettings
from .utils.logger import ad_logger from .utils.logger import ad_logger
@ -318,12 +317,11 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings):
model_config = _get_config_dict() model_config = _get_config_dict()
build_config: Optional[object] = Field( build_config: Optional[BuildConfig] = Field(
default_factory=lambda: BuildConfig(), default_factory=BuildConfig,
description="!!! DO NOT USE !!! Internal only; needed for BaseLlmArgs compatibility.", description="!!! DO NOT USE !!! Internal only; needed for BaseLlmArgs compatibility.",
exclude_from_json=True, exclude_from_json=True,
frozen=True, frozen=True,
json_schema_extra={"type": f"Optional[{get_type_repr(BuildConfig)}]"},
repr=False, repr=False,
) )
backend: Literal["_autodeploy"] = Field( backend: Literal["_autodeploy"] = Field(

View File

@ -22,8 +22,8 @@ TUNED_QUANTS = {
QuantAlgo.NVFP4, QuantAlgo.FP8, QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.NVFP4, QuantAlgo.FP8, QuantAlgo.FP8_BLOCK_SCALES,
QuantAlgo.NO_QUANT, None QuantAlgo.NO_QUANT, None
} }
DEFAULT_MAX_BATCH_SIZE = BuildConfig.max_batch_size DEFAULT_MAX_BATCH_SIZE = BuildConfig.model_fields["max_batch_size"].default
DEFAULT_MAX_NUM_TOKENS = BuildConfig.max_num_tokens DEFAULT_MAX_NUM_TOKENS = BuildConfig.model_fields["max_num_tokens"].default
def get_benchmark_engine_settings( def get_benchmark_engine_settings(
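
With BuildConfig now a Pydantic model, its defaults can no longer be read as plain class attributes the way dataclass defaults could; they live in model_fields. A standalone sketch of the pattern (ExampleConfig stands in for the real BuildConfig):

from pydantic import BaseModel, Field

class ExampleConfig(BaseModel):
    max_batch_size: int = Field(default=2048)
    max_num_tokens: int = Field(default=8192)

# Read defaults from the FieldInfo objects instead of class attributes:
DEFAULT_MAX_BATCH_SIZE = ExampleConfig.model_fields["max_batch_size"].default
DEFAULT_MAX_NUM_TOKENS = ExampleConfig.model_fields["max_num_tokens"].default
print(DEFAULT_MAX_BATCH_SIZE, DEFAULT_MAX_NUM_TOKENS)  # 2048 8192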

View File

@ -12,27 +12,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import dataclasses
import json import json
import math import math
import os import os
import shutil import shutil
import time import time
from dataclasses import dataclass, field
from functools import cache
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
import numpy as np import numpy as np
import tensorrt as trt import tensorrt as trt
from pydantic import BaseModel, Field
from ._common import _is_building, check_max_num_tokens, serialize_engine from ._common import _is_building, check_max_num_tokens, serialize_engine
from ._utils import (get_sm_version, np_bfloat16, np_float8, str_dtype_to_trt, from ._utils import (get_sm_version, np_bfloat16, np_float8, str_dtype_to_trt,
to_json_file, trt_gte) to_json_file, trt_gte)
from .bindings import KVCacheType
from .functional import PositionEmbeddingType from .functional import PositionEmbeddingType
from .graph_rewriting import optimize from .graph_rewriting import optimize
from .llmapi.kv_cache_type import KVCacheType
from .logger import logger from .logger import logger
from .lora_helper import LoraConfig from .lora_helper import LoraConfig
from .models import PretrainedConfig, PretrainedModel from .models import PretrainedConfig, PretrainedModel
@ -46,10 +43,7 @@ from .version import __version__
class ConfigEncoder(json.JSONEncoder): class ConfigEncoder(json.JSONEncoder):
def default(self, obj): def default(self, obj):
if isinstance(obj, KVCacheType): if hasattr(obj, 'model_dump'):
# For KVCacheType, convert it to string by split of 'KVCacheType.PAGED'.
return obj.__str__().split('.')[-1]
elif hasattr(obj, 'model_dump'):
# Handle Pydantic models (including DecodingBaseConfig and subclasses) # Handle Pydantic models (including DecodingBaseConfig and subclasses)
return obj.model_dump(mode='json') return obj.model_dump(mode='json')
else: else:
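
The encoder no longer needs a KVCacheType branch: a (str, Enum) member is already a string as far as json is concerned, so only Pydantic models need the model_dump fallback. A standalone sketch (the names are illustrative):

import json
from enum import Enum
from pydantic import BaseModel

class KVCacheTypeSketch(str, Enum):
    PAGED = "paged"

class PluginSketch(BaseModel):
    dtype: str = "float16"

class EncoderSketch(json.JSONEncoder):
    def default(self, obj):
        if hasattr(obj, "model_dump"):
            return obj.model_dump(mode="json")  # Pydantic models serialize themselves
        return super().default(obj)

# The str-backed enum member encodes as "paged" without any special casing:
print(json.dumps({"kv_cache_type": KVCacheTypeSketch.PAGED, "plugin": PluginSketch()},
                 cls=EncoderSketch))
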
@ -456,75 +450,112 @@ class Builder():
logger.info(f'Config saved to {config_path}.') logger.info(f'Config saved to {config_path}.')
@dataclass class BuildConfig(BaseModel):
class BuildConfig:
"""Configuration class for TensorRT LLM engine building parameters. """Configuration class for TensorRT LLM engine building parameters.
This class contains all the configuration parameters needed to build a TensorRT LLM engine, This class contains all the configuration parameters needed to build a TensorRT LLM engine,
including sequence length limits, batch sizes, optimization settings, and various features. including sequence length limits, batch sizes, optimization settings, and various features.
Args:
max_input_len (int): Maximum length of input sequences. Defaults to 1024.
max_seq_len (int, optional): The maximum possible sequence length for a single request, including both input and generated output tokens. Defaults to None.
opt_batch_size (int): Optimal batch size for engine optimization. Defaults to 8.
max_batch_size (int): Maximum batch size the engine can handle. Defaults to 2048.
max_beam_width (int): Maximum beam width for beam search decoding. Defaults to 1.
max_num_tokens (int): Maximum number of batched input tokens after padding is removed in each batch. Defaults to 8192.
opt_num_tokens (int, optional): Optimal number of batched input tokens for engine optimization. Defaults to None.
max_prompt_embedding_table_size (int): Maximum size of prompt embedding table for prompt tuning. Defaults to 0.
kv_cache_type (KVCacheType, optional): Type of KV cache to use (CONTINUOUS or PAGED). If None, defaults to PAGED. Defaults to None.
gather_context_logits (int): Whether to gather logits during context phase. Defaults to False.
gather_generation_logits (int): Whether to gather logits during generation phase. Defaults to False.
strongly_typed (bool): Whether to use strongly_typed. Defaults to True.
force_num_profiles (int, optional): Force a specific number of optimization profiles. If None, auto-determined. Defaults to None.
profiling_verbosity (str): Verbosity level for TensorRT profiling ('layer_names_only', 'detailed', 'none'). Defaults to 'layer_names_only'.
enable_debug_output (bool): Whether to enable debug output during building. Defaults to False.
max_draft_len (int): Maximum length of draft tokens for speculative decoding. Defaults to 0.
speculative_decoding_mode (SpeculativeDecodingMode): Mode for speculative decoding (NONE, MEDUSA, EAGLE, etc.). Defaults to SpeculativeDecodingMode.NONE.
use_refit (bool): Whether to enable engine refitting capabilities. Defaults to False.
input_timing_cache (str, optional): Path to input timing cache file. If None, no input cache used. Defaults to None.
output_timing_cache (str): Path to output timing cache file. Defaults to 'model.cache'.
lora_config (LoraConfig): Configuration for LoRA (Low-Rank Adaptation) fine-tuning. Defaults to default LoraConfig.
weight_sparsity (bool): Whether to enable weight sparsity optimization. Defaults to False.
weight_streaming (bool): Whether to enable weight streaming for large models. Defaults to False.
plugin_config (PluginConfig): Configuration for TensorRT LLM plugins. Defaults to default PluginConfig.
use_strip_plan (bool): Whether to use stripped plan for engine building. Defaults to False.
max_encoder_input_len (int): Maximum encoder input length for encoder-decoder models. Defaults to 1024.
dry_run (bool): Whether to perform a dry run without actually building the engine. Defaults to False.
visualize_network (str, optional): Path to save network visualization. If None, no visualization generated. Defaults to None.
monitor_memory (bool): Whether to monitor memory usage during building. Defaults to False.
use_mrope (bool): Whether to use Multi-RoPE (Rotary Position Embedding) optimization. Defaults to False.
""" """
max_input_len: int = 1024 max_input_len: int = Field(default=1024,
max_seq_len: int = None description="Maximum length of input sequences.")
opt_batch_size: int = 8 max_seq_len: Optional[int] = Field(
max_batch_size: int = 2048 default=None,
max_beam_width: int = 1 description=
max_num_tokens: int = 8192 "The maximum possible sequence length for a single request, including both input and generated "
opt_num_tokens: Optional[int] = None "output tokens.")
max_prompt_embedding_table_size: int = 0 opt_batch_size: int = Field(
kv_cache_type: KVCacheType = None default=8, description="Optimal batch size for engine optimization.")
gather_context_logits: int = False max_batch_size: int = Field(
gather_generation_logits: int = False default=2048, description="Maximum batch size the engine can handle.")
strongly_typed: bool = True max_beam_width: int = Field(
force_num_profiles: Optional[int] = None default=1, description="Maximum beam width for beam search decoding.")
profiling_verbosity: str = 'layer_names_only' max_num_tokens: int = Field(
enable_debug_output: bool = False default=8192,
max_draft_len: int = 0 description="Maximum number of batched input tokens after padding is "
speculative_decoding_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE "removed in each batch.")
use_refit: bool = False opt_num_tokens: Optional[int] = Field(
input_timing_cache: str = None default=None,
output_timing_cache: str = 'model.cache' description=
lora_config: LoraConfig = field(default_factory=LoraConfig) "Optimal number of batched input tokens for engine optimization.")
weight_sparsity: bool = False max_prompt_embedding_table_size: int = Field(
weight_streaming: bool = False default=0,
plugin_config: PluginConfig = field(default_factory=PluginConfig) description="Maximum size of prompt embedding table for prompt tuning.")
use_strip_plan: bool = False kv_cache_type: Optional[KVCacheType] = Field(
max_encoder_input_len: int = 1024 # for enc-dec DecoderModel default=None,
dry_run: bool = False description=
visualize_network: str = None "Type of KV cache to use (CONTINUOUS or PAGED). If None, defaults to PAGED."
monitor_memory: bool = False )
use_mrope: bool = False gather_context_logits: bool = Field(
default=False,
description="Whether to gather logits during context phase.")
gather_generation_logits: bool = Field(
default=False,
description="Whether to gather logits during generation phase.")
strongly_typed: bool = Field(default=True,
description="Whether to use strongly_typed.")
force_num_profiles: Optional[int] = Field(
default=None,
description=
"Force a specific number of optimization profiles. If None, auto-determined."
)
profiling_verbosity: str = Field(
default='layer_names_only',
description=
"Verbosity level for TensorRT profiling ('layer_names_only', 'detailed', 'none')."
)
enable_debug_output: bool = Field(
default=False,
description="Whether to enable debug output during building.")
max_draft_len: int = Field(
default=0,
description="Maximum length of draft tokens for speculative decoding.")
speculative_decoding_mode: SpeculativeDecodingMode = Field(
default=SpeculativeDecodingMode.NONE,
description="Mode for speculative decoding (NONE, MEDUSA, EAGLE, etc.)."
)
use_refit: bool = Field(
default=False,
description="Whether to enable engine refitting capabilities.")
input_timing_cache: Optional[str] = Field(
default=None,
description=
"Path to input timing cache file. If None, no input cache used.")
output_timing_cache: str = Field(
default='model.cache', description="Path to output timing cache file.")
lora_config: LoraConfig = Field(
default_factory=LoraConfig,
description="Configuration for LoRA (Low-Rank Adaptation) fine-tuning.")
weight_sparsity: bool = Field(
default=False,
description="Whether to enable weight sparsity optimization.")
weight_streaming: bool = Field(
default=False,
description="Whether to enable weight streaming for large models.")
plugin_config: PluginConfig = Field(
default_factory=PluginConfig,
description="Configuration for TensorRT LLM plugins.")
use_strip_plan: bool = Field(
default=False,
description="Whether to use stripped plan for engine building.")
max_encoder_input_len: int = Field(
default=1024,
description="Maximum encoder input length for encoder-decoder models.")
dry_run: bool = Field(
default=False,
description=
"Whether to perform a dry run without actually building the engine.")
visualize_network: Optional[str] = Field(
default=None,
description=
"Path to save network visualization. If None, no visualization generated."
)
monitor_memory: bool = Field(
default=False,
description="Whether to monitor memory usage during building.")
use_mrope: bool = Field(
default=False,
description=
"Whether to use Multi-RoPE (Rotary Position Embedding) optimization.")
# Since we have some overlapping between kv_cache_type, paged_kv_cache, and paged_state (later two will be deprecated in the future), # Since we have some overlapping between kv_cache_type, paged_kv_cache, and paged_state (later two will be deprecated in the future),
# we need to handle it given model architecture. # we need to handle it given model architecture.
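
Taken together, the Pydantic fields above give BuildConfig validation, coercion, and serialization for free. A sketch of the intended usage, assuming the field set shown in this diff (exact coercion behaviour depends on the installed tensorrt_llm version):

from tensorrt_llm.builder import BuildConfig

cfg = BuildConfig(max_batch_size=256, kv_cache_type="paged")  # string coerces to KVCacheType
print(cfg.max_input_len)                   # 1024, field default still applies

as_dict = cfg.model_dump(mode="json")      # replaces the hand-written to_dict()
roundtrip = BuildConfig(**as_dict)         # replaces BuildConfig.from_dict()
print(roundtrip.max_batch_size)            # 256
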
@ -574,144 +605,10 @@ class BuildConfig:
override_attri('paged_state', False) override_attri('paged_state', False)
@classmethod @classmethod
@cache def from_json_file(cls, config_file):
def get_build_config_defaults(cls):
return {
field.name: field.default
for field in dataclasses.fields(cls)
if field.default is not dataclasses.MISSING
}
@classmethod
def from_dict(cls, config, plugin_config=None):
config = copy.deepcopy(
config
) # it just does not make sense to change the input arg `config`
defaults = cls.get_build_config_defaults()
max_input_len = config.pop('max_input_len',
defaults.get('max_input_len'))
max_seq_len = config.pop('max_seq_len', defaults.get('max_seq_len'))
max_batch_size = config.pop('max_batch_size',
defaults.get('max_batch_size'))
max_beam_width = config.pop('max_beam_width',
defaults.get('max_beam_width'))
max_num_tokens = config.pop('max_num_tokens',
defaults.get('max_num_tokens'))
opt_num_tokens = config.pop('opt_num_tokens',
defaults.get('opt_num_tokens'))
opt_batch_size = config.pop('opt_batch_size',
defaults.get('opt_batch_size'))
max_prompt_embedding_table_size = config.pop(
'max_prompt_embedding_table_size',
defaults.get('max_prompt_embedding_table_size'))
if "kv_cache_type" in config and config["kv_cache_type"] is not None:
kv_cache_type = KVCacheType.from_string(config.pop('kv_cache_type'))
else:
kv_cache_type = None
gather_context_logits = config.pop(
'gather_context_logits', defaults.get('gather_context_logits'))
gather_generation_logits = config.pop(
'gather_generation_logits',
defaults.get('gather_generation_logits'))
strongly_typed = config.pop('strongly_typed',
defaults.get('strongly_typed'))
force_num_profiles = config.pop('force_num_profiles',
defaults.get('force_num_profiles'))
weight_sparsity = config.pop('weight_sparsity',
defaults.get('weight_sparsity'))
profiling_verbosity = config.pop('profiling_verbosity',
defaults.get('profiling_verbosity'))
enable_debug_output = config.pop('enable_debug_output',
defaults.get('enable_debug_output'))
max_draft_len = config.pop('max_draft_len',
defaults.get('max_draft_len'))
speculative_decoding_mode = config.pop(
'speculative_decoding_mode',
defaults.get('speculative_decoding_mode'))
use_refit = config.pop('use_refit', defaults.get('use_refit'))
input_timing_cache = config.pop('input_timing_cache',
defaults.get('input_timing_cache'))
output_timing_cache = config.pop('output_timing_cache',
defaults.get('output_timing_cache'))
lora_config = LoraConfig(**config.get('lora_config', {}))
max_encoder_input_len = config.pop(
'max_encoder_input_len', defaults.get('max_encoder_input_len'))
weight_streaming = config.pop('weight_streaming',
defaults.get('weight_streaming'))
use_strip_plan = config.pop('use_strip_plan',
defaults.get('use_strip_plan'))
if plugin_config is None:
plugin_config = PluginConfig()
if "plugin_config" in config.keys():
plugin_config = plugin_config.model_copy(
update=config["plugin_config"], deep=True)
dry_run = config.pop('dry_run', defaults.get('dry_run'))
visualize_network = config.pop('visualize_network',
defaults.get('visualize_network'))
monitor_memory = config.pop('monitor_memory',
defaults.get('monitor_memory'))
use_mrope = config.pop('use_mrope', defaults.get('use_mrope'))
return cls(
max_input_len=max_input_len,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_num_tokens=max_num_tokens,
opt_num_tokens=opt_num_tokens,
opt_batch_size=opt_batch_size,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
kv_cache_type=kv_cache_type,
gather_context_logits=gather_context_logits,
gather_generation_logits=gather_generation_logits,
strongly_typed=strongly_typed,
force_num_profiles=force_num_profiles,
profiling_verbosity=profiling_verbosity,
enable_debug_output=enable_debug_output,
max_draft_len=max_draft_len,
speculative_decoding_mode=speculative_decoding_mode,
use_refit=use_refit,
input_timing_cache=input_timing_cache,
output_timing_cache=output_timing_cache,
lora_config=lora_config,
use_strip_plan=use_strip_plan,
max_encoder_input_len=max_encoder_input_len,
weight_sparsity=weight_sparsity,
weight_streaming=weight_streaming,
plugin_config=plugin_config,
dry_run=dry_run,
visualize_network=visualize_network,
monitor_memory=monitor_memory,
use_mrope=use_mrope)
@classmethod
def from_json_file(cls, config_file, plugin_config=None):
with open(config_file) as f: with open(config_file) as f:
config = json.load(f) config = json.load(f)
return BuildConfig.from_dict(config, plugin_config=plugin_config) return BuildConfig(**config)
def to_dict(self):
output = copy.deepcopy(self.__dict__)
# the enum KVCacheType cannot be converted automatically
if output.get('kv_cache_type', None) is not None:
output['kv_cache_type'] = str(output['kv_cache_type'].name)
output['plugin_config'] = output['plugin_config'].model_dump()
output['lora_config'] = output['lora_config'].model_dump()
return output
def update_from_dict(self, config: dict):
for name, value in config.items():
if not hasattr(self, name):
raise AttributeError(
f"{self.__class__} object has no attribute {name}")
setattr(self, name, value)
def update(self, **kwargs):
self.update_from_dict(kwargs)
class EngineConfig: class EngineConfig:
@ -731,11 +628,10 @@ class EngineConfig:
def from_json_str(cls, config_str): def from_json_str(cls, config_str):
config = json.loads(config_str) config = json.loads(config_str)
return cls(PretrainedConfig.from_dict(config['pretrained_config']), return cls(PretrainedConfig.from_dict(config['pretrained_config']),
BuildConfig.from_dict(config['build_config']), BuildConfig(**config['build_config']), config['version'])
config['version'])
def to_dict(self): def to_dict(self):
build_config = self.build_config.to_dict() build_config = self.build_config.model_dump(mode="json")
build_config.pop('dry_run', None) # Not an Engine Characteristic build_config.pop('dry_run', None) # Not an Engine Characteristic
build_config.pop('visualize_network', build_config.pop('visualize_network',
None) # Not an Engine Characteristic None) # Not an Engine Characteristic
@ -1081,7 +977,7 @@ def build(model: PretrainedModel, build_config: BuildConfig) -> Engine:
''' '''
tic = time.time() tic = time.time()
# avoid changing the input config # avoid changing the input config
build_config = copy.deepcopy(build_config) build_config = build_config.model_copy(deep=True)
build_config.plugin_config.dtype = model.config.dtype build_config.plugin_config.dtype = model.config.dtype
build_config.update_kv_cache_type(model.config.architecture) build_config.update_kv_cache_type(model.config.architecture)
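
model_copy(deep=True) takes over from copy.deepcopy for the Pydantic config. A standalone sketch of the semantics (ExampleConfig is illustrative):

from pydantic import BaseModel, Field

class PluginSketch(BaseModel):
    dtype: str = "float16"

class ExampleConfig(BaseModel):
    plugin_config: PluginSketch = Field(default_factory=PluginSketch)

original = ExampleConfig()
copied = original.model_copy(deep=True)    # nested models are duplicated too
copied.plugin_config.dtype = "bfloat16"
print(original.plugin_config.dtype)        # float16 -- the input config is untouched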

View File

@ -26,8 +26,8 @@ import torch
from tensorrt_llm._utils import (local_mpi_rank, local_mpi_size, mpi_barrier, from tensorrt_llm._utils import (local_mpi_rank, local_mpi_size, mpi_barrier,
mpi_comm, mpi_rank, mpi_world_size) mpi_comm, mpi_rank, mpi_world_size)
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.builder import BuildConfig, Engine, build from tensorrt_llm.builder import BuildConfig, Engine, build
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.logger import logger, severity_map from tensorrt_llm.logger import logger, severity_map
from tensorrt_llm.lora_helper import LoraConfig from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.lora_manager import LoraManager from tensorrt_llm.lora_manager import LoraManager
@ -37,23 +37,6 @@ from tensorrt_llm.plugin import PluginConfig, add_plugin_argument
from tensorrt_llm.quantization.mode import QuantAlgo from tensorrt_llm.quantization.mode import QuantAlgo
def enum_type(enum_class):
def parse_enum(value):
if isinstance(value, enum_class):
return value
if isinstance(value, str):
return enum_class.from_string(value)
valid_values = [e.name for e in enum_class]
raise argparse.ArgumentTypeError(
f"Invalid value '{value}' of type {type(value).__name__}. Expected one of {valid_values}"
)
return parse_enum
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter) formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@ -92,29 +75,30 @@ def parse_arguments():
parser.add_argument( parser.add_argument(
'--max_batch_size', '--max_batch_size',
type=int, type=int,
default=BuildConfig.max_batch_size, default=BuildConfig.model_fields["max_batch_size"].default,
help="Maximum number of requests that the engine can schedule.") help="Maximum number of requests that the engine can schedule.")
parser.add_argument('--max_input_len', parser.add_argument(
type=int, '--max_input_len',
default=BuildConfig.max_input_len, type=int,
help="Maximum input length of one request.") default=BuildConfig.model_fields["max_input_len"].default,
help="Maximum input length of one request.")
parser.add_argument( parser.add_argument(
'--max_seq_len', '--max_seq_len',
'--max_decoder_seq_len', '--max_decoder_seq_len',
dest='max_seq_len', dest='max_seq_len',
type=int, type=int,
default=BuildConfig.max_seq_len, default=BuildConfig.model_fields["max_seq_len"].default,
help="Maximum total length of one request, including prompt and outputs. " help="Maximum total length of one request, including prompt and outputs. "
"If unspecified, the value is deduced from the model config.") "If unspecified, the value is deduced from the model config.")
parser.add_argument( parser.add_argument(
'--max_beam_width', '--max_beam_width',
type=int, type=int,
default=BuildConfig.max_beam_width, default=BuildConfig.model_fields["max_beam_width"].default,
help="Maximum number of beams for beam search decoding.") help="Maximum number of beams for beam search decoding.")
parser.add_argument( parser.add_argument(
'--max_num_tokens', '--max_num_tokens',
type=int, type=int,
default=BuildConfig.max_num_tokens, default=BuildConfig.model_fields["max_num_tokens"].default,
help= help=
"Maximum number of batched input tokens after padding is removed in each batch. " "Maximum number of batched input tokens after padding is removed in each batch. "
"Currently, the input padding is removed by default; " "Currently, the input padding is removed by default; "
@ -123,7 +107,7 @@ def parse_arguments():
parser.add_argument( parser.add_argument(
'--opt_num_tokens', '--opt_num_tokens',
type=int, type=int,
default=BuildConfig.opt_num_tokens, default=BuildConfig.model_fields["opt_num_tokens"].default,
help= help=
"Optimal number of batched input tokens after padding is removed in each batch " "Optimal number of batched input tokens after padding is removed in each batch "
"It equals to ``max_batch_size * max_beam_width`` by default, set this " "It equals to ``max_batch_size * max_beam_width`` by default, set this "
@ -132,7 +116,7 @@ def parse_arguments():
parser.add_argument( parser.add_argument(
'--max_encoder_input_len', '--max_encoder_input_len',
type=int, type=int,
default=BuildConfig.max_encoder_input_len, default=BuildConfig.model_fields["max_encoder_input_len"].default,
help="Maximum encoder input length for enc-dec models. " help="Maximum encoder input length for enc-dec models. "
"Set ``max_input_len`` to 1 to start generation from decoder_start_token_id of length 1." "Set ``max_input_len`` to 1 to start generation from decoder_start_token_id of length 1."
) )
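
Every CLI default above is now read from the corresponding FieldInfo so that trtllm-build stays in sync with BuildConfig. A standalone sketch of the wiring (ExampleBuildConfig stands in for the real model):

import argparse
from pydantic import BaseModel, Field

class ExampleBuildConfig(BaseModel):
    max_batch_size: int = Field(default=2048)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_batch_size",
    type=int,
    default=ExampleBuildConfig.model_fields["max_batch_size"].default,
    help="Maximum number of requests that the engine can schedule.")
print(parser.parse_args([]).max_batch_size)  # 2048
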
@ -140,14 +124,15 @@ def parse_arguments():
'--max_prompt_embedding_table_size', '--max_prompt_embedding_table_size',
'--max_multimodal_len', '--max_multimodal_len',
type=int, type=int,
default=BuildConfig.max_prompt_embedding_table_size, default=BuildConfig.model_fields["max_prompt_embedding_table_size"].
default,
help= help=
"Maximum prompt embedding table size for prompt tuning, or maximum multimodal input size for multimodal models. " "Maximum prompt embedding table size for prompt tuning, or maximum multimodal input size for multimodal models. "
"Setting a value > 0 enables prompt tuning or multimodal input.") "Setting a value > 0 enables prompt tuning or multimodal input.")
parser.add_argument( parser.add_argument(
'--kv_cache_type', '--kv_cache_type',
default=argparse.SUPPRESS, default=argparse.SUPPRESS,
type=enum_type(KVCacheType), type=KVCacheType,
help= help=
"Set KV cache type (continuous, paged, or disabled). For disabled case, KV cache is disabled and only context phase is allowed." "Set KV cache type (continuous, paged, or disabled). For disabled case, KV cache is disabled and only context phase is allowed."
) )
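
Passing the wrapper class itself as argparse's type= replaces the removed enum_type() helper, because calling the enum with a string (including a differently cased one) returns a member. A standalone sketch with the enum defined inline:

import argparse
from enum import Enum

class KVCacheTypeSketch(str, Enum):
    CONTINUOUS = "continuous"
    PAGED = "paged"
    DISABLED = "disabled"

    @classmethod
    def _missing_(cls, value):
        # Accept case-insensitive strings such as "PAGED" or "Paged"
        if isinstance(value, str):
            for member in cls:
                if member.value == value.lower():
                    return member
        return None

parser = argparse.ArgumentParser()
parser.add_argument("--kv_cache_type", type=KVCacheTypeSketch, default=argparse.SUPPRESS)
print(parser.parse_args(["--kv_cache_type", "PAGED"]).kv_cache_type)  # the PAGED member
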
@ -156,42 +141,44 @@ def parse_arguments():
type=str, type=str,
default=argparse.SUPPRESS, default=argparse.SUPPRESS,
help= help=
"Deprecated. Enabling this option is equvilient to ``--kv_cache_type paged`` for transformer based models." "Deprecated. Enabling this option is equivalent to ``--kv_cache_type paged`` for transformer based models."
) )
parser.add_argument( parser.add_argument(
'--input_timing_cache', '--input_timing_cache',
type=str, type=str,
default=BuildConfig.input_timing_cache, default=BuildConfig.model_fields["input_timing_cache"].default,
help= help=
"The file path to read the timing cache. This option is ignored if the file does not exist." "The file path to read the timing cache. This option is ignored if the file does not exist."
) )
parser.add_argument('--output_timing_cache', parser.add_argument(
type=str, '--output_timing_cache',
default=BuildConfig.output_timing_cache, type=str,
help="The file path to write the timing cache.") default=BuildConfig.model_fields["output_timing_cache"].default,
help="The file path to write the timing cache.")
parser.add_argument( parser.add_argument(
'--profiling_verbosity', '--profiling_verbosity',
type=str, type=str,
default=BuildConfig.profiling_verbosity, default=BuildConfig.model_fields["profiling_verbosity"].default,
choices=['layer_names_only', 'detailed', 'none'], choices=['layer_names_only', 'detailed', 'none'],
help= help=
"The profiling verbosity for the generated TensorRT engine. Setting to detailed allows inspecting tactic choices and kernel parameters." "The profiling verbosity for the generated TensorRT engine. Setting to detailed allows inspecting tactic choices and kernel parameters."
) )
parser.add_argument( parser.add_argument(
'--strip_plan', '--strip_plan',
default=BuildConfig.use_strip_plan, default=BuildConfig.model_fields["use_strip_plan"].default,
action='store_true', action='store_true',
help= help=
"Enable stripping weights from the final TensorRT engine under the assumption that the refit weights are identical to those provided at build time." "Enable stripping weights from the final TensorRT engine under the assumption that the refit weights are identical to those provided at build time."
) )
parser.add_argument('--weight_sparsity', parser.add_argument(
default=BuildConfig.weight_sparsity, '--weight_sparsity',
action='store_true', default=BuildConfig.model_fields["weight_sparsity"].default,
help="Enable weight sparsity.") action='store_true',
help="Enable weight sparsity.")
parser.add_argument( parser.add_argument(
'--weight_streaming', '--weight_streaming',
default=BuildConfig.weight_streaming, default=BuildConfig.model_fields["weight_streaming"].default,
action='store_true', action='store_true',
help= help=
"Enable offloading weights to CPU and streaming loading at runtime.", "Enable offloading weights to CPU and streaming loading at runtime.",
@ -213,10 +200,11 @@ def parse_arguments():
default='info', default='info',
choices=severity_map.keys(), choices=severity_map.keys(),
help="The logging level.") help="The logging level.")
parser.add_argument('--enable_debug_output', parser.add_argument(
default=BuildConfig.enable_debug_output, '--enable_debug_output',
action='store_true', default=BuildConfig.model_fields["enable_debug_output"].default,
help="Enable debug output.") action='store_true',
help="Enable debug output.")
parser.add_argument( parser.add_argument(
'--visualize_network', '--visualize_network',
type=str, type=str,
@ -226,7 +214,7 @@ def parse_arguments():
) )
parser.add_argument( parser.add_argument(
'--dry_run', '--dry_run',
default=BuildConfig.dry_run, default=BuildConfig.model_fields["dry_run"].default,
action='store_true', action='store_true',
help= help=
"Run through the build process except the actual Engine build for debugging." "Run through the build process except the actual Engine build for debugging."
@ -519,65 +507,37 @@ def main():
f"Overriding # of builder profiles <= {force_num_profiles_from_env}." f"Overriding # of builder profiles <= {force_num_profiles_from_env}."
) )
build_config = BuildConfig.from_dict( build_config = BuildConfig(
{ max_input_len=args.max_input_len,
'max_input_len': max_seq_len=args.max_seq_len,
args.max_input_len, max_batch_size=args.max_batch_size,
'max_seq_len': max_beam_width=args.max_beam_width,
args.max_seq_len, max_num_tokens=args.max_num_tokens,
'max_batch_size': opt_num_tokens=args.opt_num_tokens,
args.max_batch_size, max_prompt_embedding_table_size=args.
'max_beam_width': max_prompt_embedding_table_size,
args.max_beam_width, kv_cache_type=getattr(args, "kv_cache_type", None),
'max_num_tokens': gather_context_logits=args.gather_context_logits,
args.max_num_tokens, gather_generation_logits=args.gather_generation_logits,
'opt_num_tokens': strongly_typed=True,
args.opt_num_tokens, force_num_profiles=force_num_profiles_from_env,
'max_prompt_embedding_table_size': weight_sparsity=args.weight_sparsity,
args.max_prompt_embedding_table_size, profiling_verbosity=args.profiling_verbosity,
'gather_context_logits': enable_debug_output=args.enable_debug_output,
args.gather_context_logits, max_draft_len=args.max_draft_len,
'gather_generation_logits': speculative_decoding_mode=speculative_decoding_mode,
args.gather_generation_logits, input_timing_cache=args.input_timing_cache,
'strongly_typed': output_timing_cache=args.output_timing_cache,
True, dry_run=args.dry_run,
'force_num_profiles': visualize_network=args.visualize_network,
force_num_profiles_from_env, max_encoder_input_len=args.max_encoder_input_len,
'weight_sparsity': weight_streaming=args.weight_streaming,
args.weight_sparsity, monitor_memory=args.monitor_memory,
'profiling_verbosity': use_mrope=getattr(model_config, "qwen_type", None) == "qwen2_vl",
args.profiling_verbosity,
'enable_debug_output':
args.enable_debug_output,
'max_draft_len':
args.max_draft_len,
'speculative_decoding_mode':
speculative_decoding_mode,
'input_timing_cache':
args.input_timing_cache,
'output_timing_cache':
args.output_timing_cache,
'dry_run':
args.dry_run,
'visualize_network':
args.visualize_network,
'max_encoder_input_len':
args.max_encoder_input_len,
'weight_streaming':
args.weight_streaming,
'monitor_memory':
args.monitor_memory,
'use_mrope':
(True if model_config.qwen_type == "qwen2_vl" else False)
if hasattr(model_config, "qwen_type") else False
},
plugin_config=plugin_config) plugin_config=plugin_config)
if hasattr(args, 'kv_cache_type'):
build_config.update_from_dict({'kv_cache_type': args.kv_cache_type})
else: else:
build_config = BuildConfig.from_json_file(args.build_config, build_config = BuildConfig.from_json_file(args.build_config)
plugin_config=plugin_config) build_config.plugin_config = plugin_config
parallel_build(model_config, ckpt_dir, build_config, args.output_dir, parallel_build(model_config, ckpt_dir, build_config, args.output_dir,
workers, args.log_level, model_cls, **kwargs) workers, args.log_level, model_cls, **kwargs)

View File

@ -50,23 +50,23 @@ from ..logger import logger, severity_map
help="The logging level.") help="The logging level.")
@click.option("--max_beam_width", @click.option("--max_beam_width",
type=int, type=int,
default=BuildConfig.max_beam_width, default=BuildConfig.model_fields["max_beam_width"].default,
help="Maximum number of beams for beam search decoding.") help="Maximum number of beams for beam search decoding.")
@click.option("--max_batch_size", @click.option("--max_batch_size",
type=int, type=int,
default=BuildConfig.max_batch_size, default=BuildConfig.model_fields["max_batch_size"].default,
help="Maximum number of requests that the engine can schedule.") help="Maximum number of requests that the engine can schedule.")
@click.option( @click.option(
"--max_num_tokens", "--max_num_tokens",
type=int, type=int,
default=BuildConfig.max_num_tokens, default=BuildConfig.model_fields["max_num_tokens"].default,
help= help=
"Maximum number of batched input tokens after padding is removed in each batch." "Maximum number of batched input tokens after padding is removed in each batch."
) )
@click.option( @click.option(
"--max_seq_len", "--max_seq_len",
type=int, type=int,
default=BuildConfig.max_seq_len, default=BuildConfig.model_fields["max_seq_len"].default,
help="Maximum total length of one request, including prompt and outputs. " help="Maximum total length of one request, including prompt and outputs. "
"If unspecified, the value is deduced from the model config.") "If unspecified, the value is deduced from the model config.")
@click.option("--tp_size", type=int, default=1, help='Tensor parallelism size.') @click.option("--tp_size", type=int, default=1, help='Tensor parallelism size.')

View File

@ -2,7 +2,6 @@
Script that refits TRT-LLM engine(s) with weights in a TRT-LLM checkpoint. Script that refits TRT-LLM engine(s) with weights in a TRT-LLM checkpoint.
''' '''
import argparse import argparse
import copy
import json import json
import os import os
import re import re
@ -57,7 +56,7 @@ def refit_engine(engine_path: str, refit_engine_dir: str, checkpoint_dir: str,
# There are weights preprocess during optimize model. # There are weights preprocess during optimize model.
tik = time.time() tik = time.time()
build_config = copy.deepcopy(engine_config.build_config) build_config = engine_config.build_config.model_copy(deep=True)
optimize_model_with_config(model, build_config) optimize_model_with_config(model, build_config)
tok = time.time() tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))

View File

@ -75,25 +75,29 @@ def _signal_handler_cleanup_child(signum, frame):
sys.exit(128 + signum) sys.exit(128 + signum)
def get_llm_args(model: str, def get_llm_args(
tokenizer: Optional[str] = None, model: str,
backend: str = "pytorch", tokenizer: Optional[str] = None,
max_beam_width: int = BuildConfig.max_beam_width, backend: str = "pytorch",
max_batch_size: int = BuildConfig.max_batch_size, max_beam_width: int = BuildConfig.model_fields["max_beam_width"].
max_num_tokens: int = BuildConfig.max_num_tokens, default,
max_seq_len: int = BuildConfig.max_seq_len, max_batch_size: int = BuildConfig.model_fields["max_batch_size"].
tensor_parallel_size: int = 1, default,
pipeline_parallel_size: int = 1, max_num_tokens: int = BuildConfig.model_fields["max_num_tokens"].
moe_expert_parallel_size: Optional[int] = None, default,
gpus_per_node: Optional[int] = None, max_seq_len: int = BuildConfig.model_fields["max_seq_len"].default,
free_gpu_memory_fraction: float = 0.9, tensor_parallel_size: int = 1,
num_postprocess_workers: int = 0, pipeline_parallel_size: int = 1,
trust_remote_code: bool = False, moe_expert_parallel_size: Optional[int] = None,
reasoning_parser: Optional[str] = None, gpus_per_node: Optional[int] = None,
fail_fast_on_attention_window_too_large: bool = False, free_gpu_memory_fraction: float = 0.9,
otlp_traces_endpoint: Optional[str] = None, num_postprocess_workers: int = 0,
enable_chunked_prefill: bool = False, trust_remote_code: bool = False,
**llm_args_extra_dict: Any): reasoning_parser: Optional[str] = None,
fail_fast_on_attention_window_too_large: bool = False,
otlp_traces_endpoint: Optional[str] = None,
enable_chunked_prefill: bool = False,
**llm_args_extra_dict: Any):
if gpus_per_node is None: if gpus_per_node is None:
gpus_per_node = device_count() gpus_per_node = device_count()
@ -242,23 +246,23 @@ class ChoiceWithAlias(click.Choice):
help="The logging level.") help="The logging level.")
@click.option("--max_beam_width", @click.option("--max_beam_width",
type=int, type=int,
default=BuildConfig.max_beam_width, default=BuildConfig.model_fields["max_beam_width"].default,
help="Maximum number of beams for beam search decoding.") help="Maximum number of beams for beam search decoding.")
@click.option("--max_batch_size", @click.option("--max_batch_size",
type=int, type=int,
default=BuildConfig.max_batch_size, default=BuildConfig.model_fields["max_batch_size"].default,
help="Maximum number of requests that the engine can schedule.") help="Maximum number of requests that the engine can schedule.")
@click.option( @click.option(
"--max_num_tokens", "--max_num_tokens",
type=int, type=int,
default=BuildConfig.max_num_tokens, default=BuildConfig.model_fields["max_num_tokens"].default,
help= help=
"Maximum number of batched input tokens after padding is removed in each batch." "Maximum number of batched input tokens after padding is removed in each batch."
) )
@click.option( @click.option(
"--max_seq_len", "--max_seq_len",
type=int, type=int,
default=BuildConfig.max_seq_len, default=BuildConfig.model_fields["max_seq_len"].default,
help="Maximum total length of one request, including prompt and outputs. " help="Maximum total length of one request, including prompt and outputs. "
"If unspecified, the value is deduced from the model config.") "If unspecified, the value is deduced from the model config.")
@click.option("--tp_size", type=int, default=1, help='Tensor parallelism size.') @click.option("--tp_size", type=int, default=1, help='Tensor parallelism size.')
@ -436,7 +440,7 @@ def serve(
help="The logging level.") help="The logging level.")
@click.option("--max_batch_size", @click.option("--max_batch_size",
type=int, type=int,
default=BuildConfig.max_batch_size, default=BuildConfig.model_fields["max_batch_size"].default,
help="Maximum number of requests that the engine can schedule.") help="Maximum number of requests that the engine can schedule.")
@click.option( @click.option(
"--max_num_tokens", "--max_num_tokens",

View File

@ -104,7 +104,7 @@ class BuildCache:
Get the build step for engine building. Get the build step for engine building.
''' '''
build_config_str = json.dumps(self.prune_build_config_for_cache_key( build_config_str = json.dumps(self.prune_build_config_for_cache_key(
build_config.to_dict()), build_config.model_dump(mode="json")),
sort_keys=True) sort_keys=True)
kwargs_str = json.dumps(kwargs, sort_keys=True) kwargs_str = json.dumps(kwargs, sort_keys=True)
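
For the build cache, model_dump(mode="json") plus sort_keys gives a deterministic string to key on. A standalone sketch of that idea (ExampleConfig and the hashing are illustrative, not the BuildCache implementation):

import hashlib
import json
from pydantic import BaseModel

class ExampleConfig(BaseModel):
    max_batch_size: int = 2048
    dry_run: bool = False

cfg = ExampleConfig()
payload = json.dumps(cfg.model_dump(mode="json"), sort_keys=True)
print(hashlib.sha256(payload.encode()).hexdigest()[:16])  # stable across runs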

View File

@ -0,0 +1,50 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import tensorrt_llm.bindings as _bindings
class KVCacheType(str, Enum):
"""Python enum wrapper for KVCacheType.
This is a pure Python enum that mirrors the C++ KVCacheType enum exposed
through pybind11.
"""
CONTINUOUS = "continuous"
PAGED = "paged"
DISABLED = "disabled"
@classmethod
def _missing_(cls, value):
"""Allow case-insensitive string values to be converted to enum members."""
if isinstance(value, str):
for member in cls:
if member.value.lower() == value.lower():
return member
return None
def to_cpp(self) -> '_bindings.KVCacheType':
import tensorrt_llm.bindings as _bindings
return getattr(_bindings.KVCacheType, self.name)
@classmethod
def from_cpp(cls, cpp_enum) -> 'KVCacheType':
# C++ enum's __str__ returns "KVCacheType.PAGED", extract the name
name = str(cpp_enum).split('.')[-1]
return cls(name)
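
A short usage sketch for the new wrapper; the to_cpp()/from_cpp() round trip assumes the compiled tensorrt_llm.bindings module is importable:

from tensorrt_llm.llmapi.kv_cache_type import KVCacheType

assert KVCacheType("PAGED") is KVCacheType.PAGED    # case-insensitive via _missing_
assert KVCacheType.PAGED.value == "paged"           # str-backed, so JSON/argparse friendly

cpp_enum = KVCacheType.PAGED.to_cpp()               # tensorrt_llm.bindings.KVCacheType member
assert KVCacheType.from_cpp(cpp_enum) is KVCacheType.PAGED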

View File

@ -1,5 +1,4 @@
import ast import ast
import copy
import functools import functools
import json import json
import math import math
@ -1764,17 +1763,6 @@ class BaseLlmArgs(StrictBaseModel):
ret = cls(**kwargs) ret = cls(**kwargs)
return ret return ret
def to_dict(self) -> dict:
"""Dump `LlmArgs` instance to a dict.
Returns:
dict: The dict that contains all fields of the `LlmArgs` instance.
"""
model_dict = self.model_dump(mode='json')
# TODO: the BuildConfig.to_dict and from_dict don't work well with pydantic
model_dict['build_config'] = copy.deepcopy(self.build_config)
return model_dict
@staticmethod @staticmethod
def _check_consistency(kwargs_dict: Dict[str, Any]) -> Dict[str, Any]: def _check_consistency(kwargs_dict: Dict[str, Any]) -> Dict[str, Any]:
# max_beam_width is not included since vague behavior due to lacking the support for dynamic beam width during # max_beam_width is not included since vague behavior due to lacking the support for dynamic beam width during
@ -1919,10 +1907,6 @@ class BaseLlmArgs(StrictBaseModel):
if self.max_input_len: if self.max_input_len:
kwargs["max_input_len"] = self.max_input_len kwargs["max_input_len"] = self.max_input_len
self.build_config = BuildConfig(**kwargs) self.build_config = BuildConfig(**kwargs)
else:
assert isinstance(
build_config,
BuildConfig), f"build_config is not initialized: {build_config}"
return self return self
@model_validator(mode="after") @model_validator(mode="after")
@ -2001,7 +1985,7 @@ class BaseLlmArgs(StrictBaseModel):
# TODO: remove the checker when manage weights support all data types # TODO: remove the checker when manage weights support all data types
if is_trt_llm_args and self.fast_build and (self.quant_config.quant_algo if is_trt_llm_args and self.fast_build and (self.quant_config.quant_algo
is QuantAlgo.FP8): is QuantAlgo.FP8):
self._update_plugin_config("manage_weights", True) self.build_config.plugin_config.manage_weights = True
if self.parallel_config.world_size == 1 and self.build_config: if self.parallel_config.world_size == 1 and self.build_config:
self.build_config.plugin_config.nccl_plugin = None self.build_config.plugin_config.nccl_plugin = None
@ -2166,9 +2150,6 @@ class BaseLlmArgs(StrictBaseModel):
"while LoRA prefetch is not supported") "while LoRA prefetch is not supported")
return self return self
def _update_plugin_config(self, key: str, value: Any):
setattr(self.build_config.plugin_config, key, value)
def _load_config_from_engine(self, engine_dir: Path): def _load_config_from_engine(self, engine_dir: Path):
engine_config = EngineConfig.from_json_file(engine_dir / "config.json") engine_config = EngineConfig.from_json_file(engine_dir / "config.json")
self._pretrained_config = engine_config.pretrained_config self._pretrained_config = engine_config.pretrained_config
@ -2271,10 +2252,8 @@ class TrtLlmArgs(BaseLlmArgs):
fast_build: bool = Field(default=False, description="Enable fast build.") fast_build: bool = Field(default=False, description="Enable fast build.")
# BuildConfig is introduced to give users a familiar interface to configure the model building. # BuildConfig is introduced to give users a familiar interface to configure the model building.
build_config: Optional[object] = Field( build_config: Optional[BuildConfig] = Field(default=None,
default=None, description="Build config.")
description="Build config.",
json_schema_extra={"type": f"Optional[{get_type_repr(BuildConfig)}]"})
# Prompt adapter arguments # Prompt adapter arguments
enable_prompt_adapter: bool = Field(default=False, enable_prompt_adapter: bool = Field(default=False,
@ -2405,11 +2384,10 @@ class TorchCompileConfig(StrictBaseModel):
class TorchLlmArgs(BaseLlmArgs): class TorchLlmArgs(BaseLlmArgs):
# Just a dummy BuildConfig to allow code reuse with the TrtLlmArgs # Just a dummy BuildConfig to allow code reuse with the TrtLlmArgs
build_config: Optional[object] = Field( build_config: Optional[BuildConfig] = Field(
default=None, default=None,
description="Build config.", description="Build config.",
exclude_from_json=True, exclude_from_json=True,
json_schema_extra={"type": f"Optional[{get_type_repr(BuildConfig)}]"},
status="deprecated", status="deprecated",
) )
@ -2911,10 +2889,7 @@ def update_llm_args_with_extra_dict(
for field_name, field_type in field_mapping.items(): for field_name, field_type in field_mapping.items():
if field_name in llm_args_dict: if field_name in llm_args_dict:
# Some fields need to be converted manually. # Some fields need to be converted manually.
if field_name in [ if field_name in ["speculative_config", "sparse_attention_config"]:
"speculative_config", "build_config",
"sparse_attention_config"
]:
llm_args_dict[field_name] = field_type.from_dict( llm_args_dict[field_name] = field_type.from_dict(
llm_args_dict[field_name]) llm_args_dict[field_name])
else: else:
@ -2928,6 +2903,10 @@ def update_llm_args_with_extra_dict(
# For trtllm-bench or trtllm-serve, build_config may be passed for the PyTorch # For trtllm-bench or trtllm-serve, build_config may be passed for the PyTorch
# backend, overwriting the knobs there since build_config always has the highest priority # backend, overwriting the knobs there since build_config always has the highest priority
if "build_config" in llm_args: if "build_config" in llm_args:
# Ensure build_config is a BuildConfig object, not a dict
if isinstance(llm_args["build_config"], dict):
llm_args["build_config"] = BuildConfig(**llm_args["build_config"])
for key in [ for key in [
"max_batch_size", "max_batch_size",
"max_num_tokens", "max_num_tokens",

View File

@ -1,4 +1,3 @@
import copy
import json import json
import os import os
import shutil import shutil
@ -530,8 +529,8 @@ class ModelLoader:
logger_debug(f"rank{mpi_rank()} begin to build engine...\n", "green") logger_debug(f"rank{mpi_rank()} begin to build engine...\n", "green")
# avoid the original build_config is modified, avoid the side effect # avoid side effects by copying the original build_config
copied_build_config = copy.deepcopy(self.build_config) copied_build_config = self.build_config.model_copy(deep=True)
copied_build_config.update_kv_cache_type(self._model_info.architecture) copied_build_config.update_kv_cache_type(self._model_info.architecture)
assert self.model is not None, "model is loaded yet." assert self.model is not None, "model is loaded yet."

View File

@ -27,10 +27,10 @@ from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader
from ..._common import default_net, default_trtnet from ..._common import default_net, default_trtnet
from ..._utils import pad_vocab_size from ..._utils import pad_vocab_size
from ...bindings import KVCacheType
from ...functional import (Tensor, _create_tensor, cast, concat, from ...functional import (Tensor, _create_tensor, cast, concat,
gather_last_token_logits, index_select, shape) gather_last_token_logits, index_select, shape)
from ...layers import AttentionParams, ColumnLinear, SpecDecodingParams from ...layers import AttentionParams, ColumnLinear, SpecDecodingParams
from ...llmapi.kv_cache_type import KVCacheType
from ...module import Module, ModuleList from ...module import Module, ModuleList
from ...plugin import TRT_LLM_PLUGIN_NAMESPACE from ...plugin import TRT_LLM_PLUGIN_NAMESPACE
from ..modeling_utils import QuantConfig from ..modeling_utils import QuantConfig

View File

@ -18,9 +18,9 @@ from typing import List, Optional
import tensorrt as trt import tensorrt as trt
from ..bindings import KVCacheType
from ..functional import Tensor from ..functional import Tensor
from ..layers import MropeParams, SpecDecodingParams from ..layers import MropeParams, SpecDecodingParams
from ..llmapi.kv_cache_type import KVCacheType
from ..mapping import Mapping from ..mapping import Mapping
from ..plugin import current_all_reduce_helper from ..plugin import current_all_reduce_helper

View File

@ -21,7 +21,6 @@ import torch
from tensorrt_llm._common import default_net from tensorrt_llm._common import default_net
from tensorrt_llm._utils import numpy_to_torch, str_dtype_to_torch from tensorrt_llm._utils import numpy_to_torch, str_dtype_to_torch
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.functional import (Conditional, LayerNormPositionType, from tensorrt_llm.functional import (Conditional, LayerNormPositionType,
LayerNormType, MLPType, LayerNormType, MLPType,
PositionEmbeddingType, Tensor, assertion, PositionEmbeddingType, Tensor, assertion,
@ -32,6 +31,7 @@ from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams,
ColumnLinear, Embedding, FusedGatedMLP, ColumnLinear, Embedding, FusedGatedMLP,
GatedMLP, GroupNorm, KeyValueCacheParams, GatedMLP, GroupNorm, KeyValueCacheParams,
LayerNorm, LoraParams, RmsNorm) LayerNorm, LoraParams, RmsNorm)
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.lora_helper import (LoraConfig, from tensorrt_llm.lora_helper import (LoraConfig,
get_default_trtllm_modules_to_hf_modules, get_default_trtllm_modules_to_hf_modules,
use_lora) use_lora)

View File

@ -19,7 +19,6 @@ from .._common import default_net
from .._utils import (QuantModeWrapper, get_init_params, numpy_to_torch, from .._utils import (QuantModeWrapper, get_init_params, numpy_to_torch,
release_gc, str_dtype_to_torch, str_dtype_to_trt, release_gc, str_dtype_to_torch, str_dtype_to_trt,
trt_dtype_to_torch) trt_dtype_to_torch)
from ..bindings import KVCacheType
from ..bindings.executor import RuntimeDefaults from ..bindings.executor import RuntimeDefaults
from ..functional import (PositionEmbeddingType, Tensor, allgather, constant, from ..functional import (PositionEmbeddingType, Tensor, allgather, constant,
cp_split_plugin, gather_last_token_logits, cp_split_plugin, gather_last_token_logits,
@ -31,6 +30,7 @@ from ..layers.attention import Attention, BertAttention
from ..layers.linear import ColumnLinear, Linear, RowLinear from ..layers.linear import ColumnLinear, Linear, RowLinear
from ..layers.lora import Dora, Lora from ..layers.lora import Dora, Lora
from ..layers.moe import MOE, MoeOOTB from ..layers.moe import MOE, MoeOOTB
from ..llmapi.kv_cache_type import KVCacheType
from ..logger import logger from ..logger import logger
from ..mapping import Mapping from ..mapping import Mapping
from ..module import Module, ModuleList from ..module import Module, ModuleList

View File

@ -15,7 +15,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple, Type, Union from typing import List, Optional, Tuple, Type, Union
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams, from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
AttentionMaskType, PositionEmbeddingType, AttentionMaskType, PositionEmbeddingType,
Tensor, gather_last_token_logits, recv, Tensor, gather_last_token_logits, recv,
@ -28,6 +27,7 @@ from tensorrt_llm.layers.linear import ColumnLinear
from tensorrt_llm.layers.lora import LoraParams from tensorrt_llm.layers.lora import LoraParams
from tensorrt_llm.layers.mlp import GatedMLP from tensorrt_llm.layers.mlp import GatedMLP
from tensorrt_llm.layers.normalization import RmsNorm from tensorrt_llm.layers.normalization import RmsNorm
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.mapping import Mapping from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import has_safetensors from tensorrt_llm.models.convert_utils import has_safetensors
from tensorrt_llm.models.modeling_utils import DecoderModelForCausalLM from tensorrt_llm.models.modeling_utils import DecoderModelForCausalLM

View File

@ -18,8 +18,8 @@ from collections import OrderedDict
import tensorrt as trt import tensorrt as trt
from tensorrt_llm._common import default_net from tensorrt_llm._common import default_net
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.functional import Tensor, cast, categorical_sample from tensorrt_llm.functional import Tensor, cast, categorical_sample
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.models import LLaMAForCausalLM, QWenForCausalLM from tensorrt_llm.models import LLaMAForCausalLM, QWenForCausalLM
from tensorrt_llm.models.generation_mixin import GenerationMixin from tensorrt_llm.models.generation_mixin import GenerationMixin

View File

@ -31,7 +31,6 @@ from tqdm import tqdm
import tensorrt_llm import tensorrt_llm
from tensorrt_llm._common import default_net from tensorrt_llm._common import default_net
from tensorrt_llm._utils import str_dtype_to_trt, trt_dtype_to_str from tensorrt_llm._utils import str_dtype_to_trt, trt_dtype_to_str
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.functional import (ACT2FN, AttentionMaskType, LayerNormType, from tensorrt_llm.functional import (ACT2FN, AttentionMaskType, LayerNormType,
PositionEmbeddingType, Tensor, PositionEmbeddingType, Tensor,
constant_to_tensor_) constant_to_tensor_)
@ -41,6 +40,7 @@ from tensorrt_llm.layers.attention import (Attention, AttentionParams,
BertAttention, KeyValueCacheParams, BertAttention, KeyValueCacheParams,
bert_attention, layernorm_map) bert_attention, layernorm_map)
from tensorrt_llm.layers.normalization import RmsNorm from tensorrt_llm.layers.normalization import RmsNorm
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.mapping import Mapping from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.generation_mixin import GenerationMixin from tensorrt_llm.models.generation_mixin import GenerationMixin
from tensorrt_llm.models.model_weights_loader import (ModelWeightsFormat, from tensorrt_llm.models.model_weights_loader import (ModelWeightsFormat,

View File

@ -43,8 +43,9 @@ from tensorrt_llm.runtime.redrafter_utils import *
from .._utils import (binding_layer_type_to_str, binding_to_str_dtype, from .._utils import (binding_layer_type_to_str, binding_to_str_dtype,
pad_vocab_size, str_dtype_to_torch, torch_to_numpy, pad_vocab_size, str_dtype_to_torch, torch_to_numpy,
trt_dtype_to_torch) trt_dtype_to_torch)
from ..bindings import KVCacheType, ipc_nvls_allocate, ipc_nvls_free from ..bindings import ipc_nvls_allocate, ipc_nvls_free
from ..layers import LanguageAdapterConfig from ..layers import LanguageAdapterConfig
from ..llmapi.kv_cache_type import KVCacheType
from ..logger import logger from ..logger import logger
from ..lora_manager import LoraManager from ..lora_manager import LoraManager
from ..mapping import Mapping from ..mapping import Mapping

View File

@ -25,9 +25,10 @@ import torch
from .. import profiler from .. import profiler
from .._utils import mpi_comm, mpi_world_size, numpy_to_torch from .._utils import mpi_comm, mpi_world_size, numpy_to_torch
from ..bindings import KVCacheType, MpiComm from ..bindings import MpiComm
from ..bindings.executor import Executor from ..bindings.executor import Executor
from ..builder import Engine, EngineConfig, get_engine_version from ..builder import Engine, EngineConfig, get_engine_version
from ..llmapi.kv_cache_type import KVCacheType
from ..logger import logger from ..logger import logger
from ..mapping import Mapping from ..mapping import Mapping
from ..quantization import QuantMode from ..quantization import QuantMode
@ -86,7 +87,9 @@ def _builder_to_model_config(config: dict) -> Tuple[ModelConfig, dict]:
dtype = builder_config['precision'] dtype = builder_config['precision']
tp_size = builder_config['tensor_parallel'] tp_size = builder_config['tensor_parallel']
pp_size = builder_config.get('pipeline_parallel', 1) pp_size = builder_config.get('pipeline_parallel', 1)
kv_cache_type = KVCacheType.from_string(builder_config.get('kv_cache_type')) kv_cache_type = builder_config.get('kv_cache_type')
if kv_cache_type is not None:
kv_cache_type = KVCacheType(kv_cache_type)
world_size = tp_size * pp_size world_size = tp_size * pp_size
assert world_size == mpi_world_size(), \ assert world_size == mpi_world_size(), \
f'Engine world size ({tp_size} * {pp_size}) != Runtime world size ({mpi_world_size()})' f'Engine world size ({tp_size} * {pp_size}) != Runtime world size ({mpi_world_size()})'
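The hunk above replaces KVCacheType.from_string with direct construction of the new Python wrapper, and adds a None guard for engine configs that never recorded the field. A minimal sketch of the same pattern, assuming the serialized value is a plain string such as "PAGED" (the exact casing, and the builder_config stand-in, are assumptions for illustration):

    from tensorrt_llm.llmapi.kv_cache_type import KVCacheType

    # Stand-in for the loaded engine/builder config; older engines may omit the field.
    builder_config = {"kv_cache_type": "PAGED"}

    raw = builder_config.get("kv_cache_type")
    kv_cache_type = KVCacheType(raw) if raw is not None else None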

View File

@ -23,13 +23,13 @@ import torch
from .. import profiler from .. import profiler
from .._utils import mpi_broadcast from .._utils import mpi_broadcast
from ..bindings import (DataType, GptJsonConfig, KVCacheType, ModelConfig, from ..bindings import DataType, GptJsonConfig, ModelConfig, WorldConfig
WorldConfig)
from ..bindings import executor as trtllm from ..bindings import executor as trtllm
from ..bindings.executor import (DecodingMode, ExternalDraftTokensConfig, from ..bindings.executor import (DecodingMode, ExternalDraftTokensConfig,
OrchestratorConfig, ParallelConfig) OrchestratorConfig, ParallelConfig)
from ..builder import EngineConfig from ..builder import EngineConfig
from ..layers import MropeParams from ..layers import MropeParams
from ..llmapi.kv_cache_type import KVCacheType
from ..logger import logger from ..logger import logger
from ..mapping import Mapping from ..mapping import Mapping
from .generation import LogitsProcessor, LoraManager from .generation import LogitsProcessor, LoraManager
@ -248,7 +248,8 @@ class ModelRunnerCpp(ModelRunnerMixin):
json_config = GptJsonConfig.parse_file(config_path) json_config = GptJsonConfig.parse_file(config_path)
model_config = json_config.model_config model_config = json_config.model_config
use_kv_cache = model_config.kv_cache_type != KVCacheType.DISABLED use_kv_cache = KVCacheType.from_cpp(
model_config.kv_cache_type) != KVCacheType.DISABLED
if not model_config.use_cross_attention: if not model_config.use_cross_attention:
assert cross_kv_cache_fraction is None, "cross_kv_cache_fraction should only be used with enc-dec models." assert cross_kv_cache_fraction is None, "cross_kv_cache_fraction should only be used with enc-dec models."
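ModelRunnerCpp receives the C++ binding enum from GptJsonConfig, so the disabled-cache check now converts it through the wrapper's from_cpp before comparing. A sketch of that check, with config_path standing in for the engine's config.json as in the hunk above:

    from tensorrt_llm.bindings import GptJsonConfig
    from tensorrt_llm.llmapi.kv_cache_type import KVCacheType

    json_config = GptJsonConfig.parse_file(config_path)
    model_config = json_config.model_config

    # model_config.kv_cache_type is the C++ binding enum; convert it to the
    # Python wrapper before comparing against wrapper members.
    use_kv_cache = KVCacheType.from_cpp(model_config.kv_cache_type) != KVCacheType.DISABLED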

View File

@ -671,7 +671,8 @@ class CliFlowAccuracyTestHarness:
f"--max_tokens_in_paged_kv_cache={max_tokens_in_paged_kv_cache}" f"--max_tokens_in_paged_kv_cache={max_tokens_in_paged_kv_cache}"
]) ])
if task.MAX_INPUT_LEN + task.MAX_OUTPUT_LEN > BuildConfig.max_num_tokens: if task.MAX_INPUT_LEN + task.MAX_OUTPUT_LEN > BuildConfig.model_fields[
"max_num_tokens"].default:
summarize_cmd.append("--enable_chunked_context") summarize_cmd.append("--enable_chunked_context")
if self.extra_summarize_args: if self.extra_summarize_args:
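With BuildConfig now a Pydantic model, the declared default of max_num_tokens is read through model_fields rather than as a class attribute. A small illustration using the standard Pydantic v2 API (the top-level import is an assumption):

    from tensorrt_llm import BuildConfig  # assumed top-level export

    # model_fields maps field names to FieldInfo objects carrying the declared default.
    default_max_num_tokens = BuildConfig.model_fields["max_num_tokens"].default
    print(default_max_num_tokens)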

View File

@ -142,14 +142,14 @@ def test_llmapi_build_command_parameters_align(llm_root, llm_venv, engine_dir,
with open(os.path.join(engine_dir, "config.json"), "r") as f: with open(os.path.join(engine_dir, "config.json"), "r") as f:
engine_config = json.load(f) engine_config = json.load(f)
build_cmd_cfg = BuildConfig.from_dict( build_cmd_cfg = BuildConfig(
engine_config["build_config"]).to_dict() **engine_config["build_config"]).model_dump()
with open(os.path.join(tmpdir.name, "config.json"), "r") as f: with open(os.path.join(tmpdir.name, "config.json"), "r") as f:
llm_api_engine_cfg = json.load(f) llm_api_engine_cfg = json.load(f)
build_llmapi_cfg = BuildConfig.from_dict( build_llmapi_cfg = BuildConfig(
llm_api_engine_cfg["build_config"]).to_dict() **llm_api_engine_cfg["build_config"]).model_dump()
assert build_cmd_cfg == build_llmapi_cfg assert build_cmd_cfg == build_llmapi_cfg
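The alignment test swaps BuildConfig.from_dict(...)/.to_dict() for the Pydantic constructor and model_dump(). A hedged sketch of the same round trip; the file path and the top-level import are assumptions:

    import json

    from tensorrt_llm import BuildConfig  # assumed top-level export

    with open("engine_dir/config.json") as f:  # hypothetical engine config path
        engine_config = json.load(f)

    # BuildConfig.from_dict(d) -> BuildConfig(**d); cfg.to_dict() -> cfg.model_dump()
    build_cfg = BuildConfig(**engine_config["build_config"])
    assert build_cfg.model_dump()["max_num_tokens"] == build_cfg.max_num_tokens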

View File

@ -1636,7 +1636,7 @@ def get_allowed_models(benchmark_type=None):
if i.benchmark_type == benchmark_type) if i.benchmark_type == benchmark_type)
def get_build_config(model_name, return_dict=True) -> Union[BuildConfig]: def get_build_config(model_name, return_dict=True) -> Union[Dict, BuildConfig]:
if model_name in _allowed_configs: if model_name in _allowed_configs:
cfg = _allowed_configs[model_name].build_config cfg = _allowed_configs[model_name].build_config
return asdict(cfg) if return_dict else cfg return asdict(cfg) if return_dict else cfg

View File

@ -255,7 +255,7 @@ def get_quant_config(quantization: str):
def build_gpt(args): def build_gpt(args):
build_config = get_build_config(args.model) build_config = get_build_config(args.model)
build_config = BuildConfig.from_dict(build_config) build_config = BuildConfig(**build_config)
model_config = get_model_config(args.model) model_config = get_model_config(args.model)
if args.force_num_layer_1: if args.force_num_layer_1:
model_config['num_layers'] = 1 model_config['num_layers'] = 1
@ -1448,7 +1448,7 @@ def enc_dec_build_helper(component, build_config, model_config, args):
def build_enc_dec(args): def build_enc_dec(args):
build_config = get_build_config(args.model) build_config = get_build_config(args.model)
build_config = BuildConfig.from_dict(build_config) build_config = BuildConfig(**build_config)
model_config = get_model_config(args.model) model_config = get_model_config(args.model)
if args.force_num_layer_1: if args.force_num_layer_1:
model_config['num_layers'] = 1 model_config['num_layers'] = 1

View File

@ -9,6 +9,7 @@ import torch
from utils.runtime_defaults import assert_runtime_defaults_are_parsed_correctly from utils.runtime_defaults import assert_runtime_defaults_are_parsed_correctly
import tensorrt_llm.bindings as _tb import tensorrt_llm.bindings as _tb
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.mapping import Mapping from tensorrt_llm.mapping import Mapping
@ -85,6 +86,7 @@ def test_model_config():
assert model_config.use_packed_input assert model_config.use_packed_input
assert model_config.kv_cache_type is not None assert model_config.kv_cache_type is not None
# Test with C++ enums directly
for enum_val in [ for enum_val in [
_tb.KVCacheType.CONTINUOUS, _tb.KVCacheType.PAGED, _tb.KVCacheType.CONTINUOUS, _tb.KVCacheType.PAGED,
_tb.KVCacheType.DISABLED _tb.KVCacheType.DISABLED
@ -92,6 +94,17 @@ def test_model_config():
model_config.kv_cache_type = enum_val model_config.kv_cache_type = enum_val
assert model_config.kv_cache_type == enum_val assert model_config.kv_cache_type == enum_val
# Test with Python enums converted to C++
for py_enum in [
KVCacheType.CONTINUOUS, KVCacheType.PAGED, KVCacheType.DISABLED
]:
model_config.kv_cache_type = py_enum.to_cpp()
# Verify it was set correctly by comparing with C++ enum
assert model_config.kv_cache_type == getattr(_tb.KVCacheType,
py_enum.name)
# Also verify round-trip conversion works
assert KVCacheType.from_cpp(model_config.kv_cache_type) == py_enum
assert model_config.tokens_per_block == 64 assert model_config.tokens_per_block == 64
tokens_per_block = 1024 tokens_per_block = 1024
model_config.tokens_per_block = tokens_per_block model_config.tokens_per_block = tokens_per_block
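The added assertions exercise the bridge between the Python wrapper and the C++ binding enum in both directions; the same round trip in isolation looks like this:

    import tensorrt_llm.bindings as _tb
    from tensorrt_llm.llmapi.kv_cache_type import KVCacheType

    py_enum = KVCacheType.PAGED
    cpp_enum = py_enum.to_cpp()                        # Python wrapper -> C++ binding enum
    assert cpp_enum == getattr(_tb.KVCacheType, py_enum.name)
    assert KVCacheType.from_cpp(cpp_enum) == py_enum   # and back to the wrapper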

View File

@ -202,7 +202,7 @@ def test_llm_build_config():
# read the build_config and check if the parameters are correctly saved # read the build_config and check if the parameters are correctly saved
engine_config = json.load(f) engine_config = json.load(f)
build_config1 = BuildConfig.from_dict(engine_config["build_config"]) build_config1 = BuildConfig(**engine_config["build_config"])
# Known issue: this will be converted to None after saving the engine for single-gpu # Known issue: this will be converted to None after saving the engine for single-gpu
build_config1.plugin_config.nccl_plugin = 'float16' build_config1.plugin_config.nccl_plugin = 'float16'

View File

@ -64,7 +64,7 @@ speculative_config:
dict_content = self._yaml_to_dict(yaml_content) dict_content = self._yaml_to_dict(yaml_content)
llm_args = TrtLlmArgs(model=llama_model_path) llm_args = TrtLlmArgs(model=llama_model_path)
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(), llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
dict_content) dict_content)
llm_args = TrtLlmArgs(**llm_args_dict) llm_args = TrtLlmArgs(**llm_args_dict)
assert llm_args.speculative_config.max_window_size == 4 assert llm_args.speculative_config.max_window_size == 4
@ -80,7 +80,7 @@ pytorch_backend_config: # this is deprecated
dict_content = self._yaml_to_dict(yaml_content) dict_content = self._yaml_to_dict(yaml_content)
llm_args = TrtLlmArgs(model=llama_model_path) llm_args = TrtLlmArgs(model=llama_model_path)
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(), llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
dict_content) dict_content)
with pytest.raises(ValueError): with pytest.raises(ValueError):
llm_args = TrtLlmArgs(**llm_args_dict) llm_args = TrtLlmArgs(**llm_args_dict)
@ -96,7 +96,7 @@ build_config:
dict_content = self._yaml_to_dict(yaml_content) dict_content = self._yaml_to_dict(yaml_content)
llm_args = TrtLlmArgs(model=llama_model_path) llm_args = TrtLlmArgs(model=llama_model_path)
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(), llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
dict_content) dict_content)
llm_args = TrtLlmArgs(**llm_args_dict) llm_args = TrtLlmArgs(**llm_args_dict)
assert llm_args.build_config.max_beam_width == 4 assert llm_args.build_config.max_beam_width == 4
@ -113,7 +113,7 @@ kv_cache_config:
dict_content = self._yaml_to_dict(yaml_content) dict_content = self._yaml_to_dict(yaml_content)
llm_args = TrtLlmArgs(model=llama_model_path) llm_args = TrtLlmArgs(model=llama_model_path)
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(), llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
dict_content) dict_content)
llm_args = TrtLlmArgs(**llm_args_dict) llm_args = TrtLlmArgs(**llm_args_dict)
assert llm_args.kv_cache_config.enable_block_reuse == True assert llm_args.kv_cache_config.enable_block_reuse == True
@ -131,7 +131,7 @@ max_seq_len: 128
dict_content = self._yaml_to_dict(yaml_content) dict_content = self._yaml_to_dict(yaml_content)
llm_args = TrtLlmArgs(model=llama_model_path) llm_args = TrtLlmArgs(model=llama_model_path)
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(), llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
dict_content) dict_content)
llm_args = TrtLlmArgs(**llm_args_dict) llm_args = TrtLlmArgs(**llm_args_dict)
assert llm_args.max_batch_size == 16 assert llm_args.max_batch_size == 16
@ -331,7 +331,7 @@ def test_update_llm_args_with_extra_dict_with_nested_dict():
lora_config=LoraConfig(lora_ckpt_source='hf'), lora_config=LoraConfig(lora_ckpt_source='hf'),
plugin_config=plugin_config) plugin_config=plugin_config)
extra_llm_args_dict = { extra_llm_args_dict = {
"build_config": build_config.to_dict(), "build_config": build_config.model_dump(mode="json"),
} }
llm_api_args_dict = update_llm_args_with_extra_dict(llm_api_args_dict, llm_api_args_dict = update_llm_args_with_extra_dict(llm_api_args_dict,
@ -352,8 +352,9 @@ def test_update_llm_args_with_extra_dict_with_nested_dict():
raise ValueError(f"Mismatch at {path}: {dict1} != {dict2}") raise ValueError(f"Mismatch at {path}: {dict1} != {dict2}")
return True return True
build_config_dict1 = build_config.to_dict() build_config_dict1 = build_config.model_dump(mode="json")
build_config_dict2 = initialized_llm_args.build_config.to_dict() build_config_dict2 = initialized_llm_args.build_config.model_dump(
mode="json")
check_nested_dict_equality(build_config_dict1, build_config_dict2) check_nested_dict_equality(build_config_dict1, build_config_dict2)
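Where the dumped BuildConfig is compared against JSON-loaded data or fed back through update_llm_args_with_extra_dict, the tests dump in JSON mode so enum-like fields serialize to plain values. A brief illustration; the field values are placeholders and the behaviour shown is standard Pydantic v2:

    from tensorrt_llm import BuildConfig  # assumed top-level export

    build_config = BuildConfig(max_batch_size=8, max_num_tokens=256)

    python_dump = build_config.model_dump()           # keeps rich Python values
    json_dump = build_config.model_dump(mode="json")  # coerces to JSON-safe types

    extra_llm_args_dict = {"build_config": json_dump}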
@ -498,11 +499,11 @@ class TestTrtLlmArgs:
max_num_tokens=256, max_num_tokens=256,
) )
args = TrtLlmArgs(model=llama_model_path, build_config=build_config) args = TrtLlmArgs(model=llama_model_path, build_config=build_config)
args_dict = args.to_dict() args_dict = args.model_dump()
new_args = TrtLlmArgs.from_kwargs(**args_dict) new_args = TrtLlmArgs.from_kwargs(**args_dict)
assert new_args.to_dict() == args_dict assert new_args.model_dump() == args_dict
def test_build_config_from_engine(self): def test_build_config_from_engine(self):
build_config = BuildConfig(max_batch_size=8, max_num_tokens=256) build_config = BuildConfig(max_batch_size=8, max_num_tokens=256)
@ -522,6 +523,36 @@ class TestTrtLlmArgs:
assert args.max_num_tokens == 16 assert args.max_num_tokens == 16
assert args.max_batch_size == 4 assert args.max_batch_size == 4
def test_model_dump_does_not_mutate_original(self):
"""Test that model_dump() and update_llm_args_with_extra_dict don't mutate the original."""
# Create args with specific build_config values
build_config = BuildConfig(
max_batch_size=8,
max_num_tokens=256,
)
args = TrtLlmArgs(model=llama_model_path, build_config=build_config)
# Store original values
original_max_batch_size = args.build_config.max_batch_size
original_max_num_tokens = args.build_config.max_num_tokens
# Convert to dict and pass through update_llm_args_with_extra_dict with overrides
args_dict = args.model_dump()
extra_dict = {
"max_batch_size": 128,
"max_num_tokens": 1024,
}
updated_dict = update_llm_args_with_extra_dict(args_dict, extra_dict)
# Verify original args was NOT mutated
assert args.build_config.max_batch_size == original_max_batch_size
assert args.build_config.max_num_tokens == original_max_num_tokens
# Verify updated dict has new values
new_args = TrtLlmArgs(**updated_dict)
assert new_args.build_config.max_batch_size == 128
assert new_args.build_config.max_num_tokens == 1024
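The property the new test pins down is that model_dump() hands back an independent copy, so downstream edits never write through to the original args. A compact, hedged illustration of that half of the guarantee (llama_model_path as in the tests above):

    args = TrtLlmArgs(model=llama_model_path, build_config=BuildConfig(max_batch_size=8))
    dumped = args.model_dump()

    dumped["build_config"]["max_batch_size"] = 128   # edit the copy only
    assert args.build_config.max_batch_size == 8     # the original model is untouched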
class TestStrictBaseModelArbitraryArgs: class TestStrictBaseModelArbitraryArgs:
"""Test that StrictBaseModel prevents arbitrary arguments from being accepted.""" """Test that StrictBaseModel prevents arbitrary arguments from being accepted."""