From a09b38a862a4e6fd509a43d1e8e4a228efbe0707 Mon Sep 17 00:00:00 2001 From: Anish Shanbhag Date: Tue, 28 Oct 2025 09:17:26 -0700 Subject: [PATCH] [TRTLLM-8684][chore] Migrate BuildConfig to Pydantic, add a Python wrapper for KVCacheType enum (#8330) Signed-off-by: Anish Shanbhag --- examples/models/core/llama/summarize_long.py | 4 +- examples/models/core/qwen2audio/run.py | 5 +- examples/models/core/qwenvl/run.py | 5 +- examples/models/core/whisper/run.py | 3 +- tensorrt_llm/_torch/auto_deploy/llm_args.py | 6 +- tensorrt_llm/bench/build/build.py | 4 +- tensorrt_llm/builder.py | 322 ++++++------------ tensorrt_llm/commands/build.py | 170 ++++----- tensorrt_llm/commands/eval.py | 8 +- tensorrt_llm/commands/refit.py | 3 +- tensorrt_llm/commands/serve.py | 52 +-- tensorrt_llm/llmapi/build_cache.py | 2 +- tensorrt_llm/llmapi/kv_cache_type.py | 50 +++ tensorrt_llm/llmapi/llm_args.py | 39 +-- tensorrt_llm/llmapi/llm_utils.py | 5 +- tensorrt_llm/models/eagle/model.py | 2 +- tensorrt_llm/models/generation_mixin.py | 2 +- tensorrt_llm/models/mllama/model.py | 2 +- tensorrt_llm/models/modeling_utils.py | 2 +- tensorrt_llm/models/nemotron_nas/model.py | 2 +- tensorrt_llm/models/redrafter/model.py | 2 +- tensorrt_llm/models/stdit/model.py | 2 +- tensorrt_llm/runtime/generation.py | 3 +- tensorrt_llm/runtime/model_runner.py | 7 +- tensorrt_llm/runtime/model_runner_cpp.py | 7 +- .../defs/accuracy/accuracy_core.py | 3 +- tests/integration/defs/llmapi/test_llm_e2e.py | 8 +- .../integration/defs/perf/allowed_configs.py | 2 +- tests/integration/defs/perf/build.py | 4 +- tests/unittest/bindings/test_bindings_ut.py | 13 + tests/unittest/llmapi/test_llm.py | 2 +- tests/unittest/llmapi/test_llm_args.py | 51 ++- 32 files changed, 363 insertions(+), 429 deletions(-) create mode 100644 tensorrt_llm/llmapi/kv_cache_type.py diff --git a/examples/models/core/llama/summarize_long.py b/examples/models/core/llama/summarize_long.py index 7ec2b954d7..f0adbe8072 100644 --- a/examples/models/core/llama/summarize_long.py +++ b/examples/models/core/llama/summarize_long.py @@ -23,7 +23,7 @@ from transformers import AutoModelForCausalLM, LlamaTokenizer import tensorrt_llm import tensorrt_llm.profiler as profiler -from tensorrt_llm.bindings import KVCacheType +from tensorrt_llm.llmapi.kv_cache_type import KVCacheType from tensorrt_llm.logger import logger from tensorrt_llm.quantization import QuantMode @@ -97,7 +97,7 @@ def TRTLLaMA(args, config): quantization_config = pretrained_config['quantization'] build_config = config['build_config'] - kv_cache_type = KVCacheType.from_string(build_config['kv_cache_type']) + kv_cache_type = KVCacheType(build_config['kv_cache_type']) plugin_config = build_config['plugin_config'] dtype = pretrained_config['dtype'] diff --git a/examples/models/core/qwen2audio/run.py b/examples/models/core/qwen2audio/run.py index 0c72eded66..a0b0a68fb1 100644 --- a/examples/models/core/qwen2audio/run.py +++ b/examples/models/core/qwen2audio/run.py @@ -27,7 +27,7 @@ from utils import add_common_args import tensorrt_llm import tensorrt_llm.profiler as profiler from tensorrt_llm import logger -from tensorrt_llm.bindings import KVCacheType +from tensorrt_llm.llmapi.kv_cache_type import KVCacheType from tensorrt_llm.quantization import QuantMode from tensorrt_llm.runtime import (PYTHON_BINDINGS, ModelConfig, ModelRunner, SamplingConfig, Session, TensorInfo) @@ -122,8 +122,7 @@ class QWenInfer(object): num_kv_heads = config["pretrained_config"].get("num_key_value_heads", num_heads) if "kv_cache_type" in 
config["build_config"]: - kv_cache_type = KVCacheType.from_string( - config["build_config"]["kv_cache_type"]) + kv_cache_type = KVCacheType(config["build_config"]["kv_cache_type"]) else: kv_cache_type = KVCacheType.CONTINUOUS diff --git a/examples/models/core/qwenvl/run.py b/examples/models/core/qwenvl/run.py index 7013217429..f1530aee80 100644 --- a/examples/models/core/qwenvl/run.py +++ b/examples/models/core/qwenvl/run.py @@ -25,7 +25,7 @@ from vit_onnx_trt import Preprocss import tensorrt_llm import tensorrt_llm.profiler as profiler from tensorrt_llm import logger -from tensorrt_llm.bindings import KVCacheType +from tensorrt_llm.llmapi.kv_cache_type import KVCacheType from tensorrt_llm.quantization import QuantMode from tensorrt_llm.runtime import (ModelConfig, SamplingConfig, Session, TensorInfo) @@ -118,8 +118,7 @@ class QWenInfer(object): num_kv_heads = config["pretrained_config"].get("num_key_value_heads", num_heads) if "kv_cache_type" in config["build_config"]: - kv_cache_type = KVCacheType.from_string( - config["build_config"]["kv_cache_type"]) + kv_cache_type = KVCacheType(config["build_config"]["kv_cache_type"]) else: kv_cache_type = KVCacheType.CONTINUOUS diff --git a/examples/models/core/whisper/run.py b/examples/models/core/whisper/run.py index 2e714c1d95..48ea8b10a0 100755 --- a/examples/models/core/whisper/run.py +++ b/examples/models/core/whisper/run.py @@ -33,7 +33,8 @@ import tensorrt_llm import tensorrt_llm.logger as logger from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt, trt_dtype_to_torch) -from tensorrt_llm.bindings import GptJsonConfig, KVCacheType +from tensorrt_llm.bindings import GptJsonConfig +from tensorrt_llm.llmapi.kv_cache_type import KVCacheType from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelConfig, SamplingConfig from tensorrt_llm.runtime.session import Session, TensorInfo diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index 1d8b95d4db..6f75150cba 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -9,7 +9,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict from tensorrt_llm.models.modeling_utils import QuantConfig from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, _ParallelConfig -from ...llmapi.utils import get_type_repr from .models import ModelFactory, ModelFactoryRegistry from .utils._config import DynamicYamlMixInForSettings from .utils.logger import ad_logger @@ -318,12 +317,11 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings): model_config = _get_config_dict() - build_config: Optional[object] = Field( - default_factory=lambda: BuildConfig(), + build_config: Optional[BuildConfig] = Field( + default_factory=BuildConfig, description="!!! DO NOT USE !!! 
Internal only; needed for BaseLlmArgs compatibility.", exclude_from_json=True, frozen=True, - json_schema_extra={"type": f"Optional[{get_type_repr(BuildConfig)}]"}, repr=False, ) backend: Literal["_autodeploy"] = Field( diff --git a/tensorrt_llm/bench/build/build.py b/tensorrt_llm/bench/build/build.py index 4de393a5ec..4a9210628a 100644 --- a/tensorrt_llm/bench/build/build.py +++ b/tensorrt_llm/bench/build/build.py @@ -22,8 +22,8 @@ TUNED_QUANTS = { QuantAlgo.NVFP4, QuantAlgo.FP8, QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.NO_QUANT, None } -DEFAULT_MAX_BATCH_SIZE = BuildConfig.max_batch_size -DEFAULT_MAX_NUM_TOKENS = BuildConfig.max_num_tokens +DEFAULT_MAX_BATCH_SIZE = BuildConfig.model_fields["max_batch_size"].default +DEFAULT_MAX_NUM_TOKENS = BuildConfig.model_fields["max_num_tokens"].default def get_benchmark_engine_settings( diff --git a/tensorrt_llm/builder.py b/tensorrt_llm/builder.py index 177f57a6b3..0fef5d8806 100644 --- a/tensorrt_llm/builder.py +++ b/tensorrt_llm/builder.py @@ -12,27 +12,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy -import dataclasses import json import math import os import shutil import time -from dataclasses import dataclass, field -from functools import cache from pathlib import Path from typing import Dict, Optional, Union import numpy as np import tensorrt as trt +from pydantic import BaseModel, Field from ._common import _is_building, check_max_num_tokens, serialize_engine from ._utils import (get_sm_version, np_bfloat16, np_float8, str_dtype_to_trt, to_json_file, trt_gte) -from .bindings import KVCacheType from .functional import PositionEmbeddingType from .graph_rewriting import optimize +from .llmapi.kv_cache_type import KVCacheType from .logger import logger from .lora_helper import LoraConfig from .models import PretrainedConfig, PretrainedModel @@ -46,10 +43,7 @@ from .version import __version__ class ConfigEncoder(json.JSONEncoder): def default(self, obj): - if isinstance(obj, KVCacheType): - # For KVCacheType, convert it to string by split of 'KVCacheType.PAGED'. - return obj.__str__().split('.')[-1] - elif hasattr(obj, 'model_dump'): + if hasattr(obj, 'model_dump'): # Handle Pydantic models (including DecodingBaseConfig and subclasses) return obj.model_dump(mode='json') else: @@ -456,75 +450,112 @@ class Builder(): logger.info(f'Config saved to {config_path}.') -@dataclass -class BuildConfig: +class BuildConfig(BaseModel): """Configuration class for TensorRT LLM engine building parameters. This class contains all the configuration parameters needed to build a TensorRT LLM engine, including sequence length limits, batch sizes, optimization settings, and various features. - - Args: - max_input_len (int): Maximum length of input sequences. Defaults to 1024. - max_seq_len (int, optional): The maximum possible sequence length for a single request, including both input and generated output tokens. Defaults to None. - opt_batch_size (int): Optimal batch size for engine optimization. Defaults to 8. - max_batch_size (int): Maximum batch size the engine can handle. Defaults to 2048. - max_beam_width (int): Maximum beam width for beam search decoding. Defaults to 1. - max_num_tokens (int): Maximum number of batched input tokens after padding is removed in each batch. Defaults to 8192. - opt_num_tokens (int, optional): Optimal number of batched input tokens for engine optimization. Defaults to None. 
- max_prompt_embedding_table_size (int): Maximum size of prompt embedding table for prompt tuning. Defaults to 0. - kv_cache_type (KVCacheType, optional): Type of KV cache to use (CONTINUOUS or PAGED). If None, defaults to PAGED. Defaults to None. - gather_context_logits (int): Whether to gather logits during context phase. Defaults to False. - gather_generation_logits (int): Whether to gather logits during generation phase. Defaults to False. - strongly_typed (bool): Whether to use strongly_typed. Defaults to True. - force_num_profiles (int, optional): Force a specific number of optimization profiles. If None, auto-determined. Defaults to None. - profiling_verbosity (str): Verbosity level for TensorRT profiling ('layer_names_only', 'detailed', 'none'). Defaults to 'layer_names_only'. - enable_debug_output (bool): Whether to enable debug output during building. Defaults to False. - max_draft_len (int): Maximum length of draft tokens for speculative decoding. Defaults to 0. - speculative_decoding_mode (SpeculativeDecodingMode): Mode for speculative decoding (NONE, MEDUSA, EAGLE, etc.). Defaults to SpeculativeDecodingMode.NONE. - use_refit (bool): Whether to enable engine refitting capabilities. Defaults to False. - input_timing_cache (str, optional): Path to input timing cache file. If None, no input cache used. Defaults to None. - output_timing_cache (str): Path to output timing cache file. Defaults to 'model.cache'. - lora_config (LoraConfig): Configuration for LoRA (Low-Rank Adaptation) fine-tuning. Defaults to default LoraConfig. - weight_sparsity (bool): Whether to enable weight sparsity optimization. Defaults to False. - weight_streaming (bool): Whether to enable weight streaming for large models. Defaults to False. - plugin_config (PluginConfig): Configuration for TensorRT LLM plugins. Defaults to default PluginConfig. - use_strip_plan (bool): Whether to use stripped plan for engine building. Defaults to False. - max_encoder_input_len (int): Maximum encoder input length for encoder-decoder models. Defaults to 1024. - dry_run (bool): Whether to perform a dry run without actually building the engine. Defaults to False. - visualize_network (str, optional): Path to save network visualization. If None, no visualization generated. Defaults to None. - monitor_memory (bool): Whether to monitor memory usage during building. Defaults to False. - use_mrope (bool): Whether to use Multi-RoPE (Rotary Position Embedding) optimization. Defaults to False. 
""" - max_input_len: int = 1024 - max_seq_len: int = None - opt_batch_size: int = 8 - max_batch_size: int = 2048 - max_beam_width: int = 1 - max_num_tokens: int = 8192 - opt_num_tokens: Optional[int] = None - max_prompt_embedding_table_size: int = 0 - kv_cache_type: KVCacheType = None - gather_context_logits: int = False - gather_generation_logits: int = False - strongly_typed: bool = True - force_num_profiles: Optional[int] = None - profiling_verbosity: str = 'layer_names_only' - enable_debug_output: bool = False - max_draft_len: int = 0 - speculative_decoding_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE - use_refit: bool = False - input_timing_cache: str = None - output_timing_cache: str = 'model.cache' - lora_config: LoraConfig = field(default_factory=LoraConfig) - weight_sparsity: bool = False - weight_streaming: bool = False - plugin_config: PluginConfig = field(default_factory=PluginConfig) - use_strip_plan: bool = False - max_encoder_input_len: int = 1024 # for enc-dec DecoderModel - dry_run: bool = False - visualize_network: str = None - monitor_memory: bool = False - use_mrope: bool = False + max_input_len: int = Field(default=1024, + description="Maximum length of input sequences.") + max_seq_len: Optional[int] = Field( + default=None, + description= + "The maximum possible sequence length for a single request, including both input and generated " + "output tokens.") + opt_batch_size: int = Field( + default=8, description="Optimal batch size for engine optimization.") + max_batch_size: int = Field( + default=2048, description="Maximum batch size the engine can handle.") + max_beam_width: int = Field( + default=1, description="Maximum beam width for beam search decoding.") + max_num_tokens: int = Field( + default=8192, + description="Maximum number of batched input tokens after padding is " + "removed in each batch.") + opt_num_tokens: Optional[int] = Field( + default=None, + description= + "Optimal number of batched input tokens for engine optimization.") + max_prompt_embedding_table_size: int = Field( + default=0, + description="Maximum size of prompt embedding table for prompt tuning.") + kv_cache_type: Optional[KVCacheType] = Field( + default=None, + description= + "Type of KV cache to use (CONTINUOUS or PAGED). If None, defaults to PAGED." + ) + gather_context_logits: bool = Field( + default=False, + description="Whether to gather logits during context phase.") + gather_generation_logits: bool = Field( + default=False, + description="Whether to gather logits during generation phase.") + strongly_typed: bool = Field(default=True, + description="Whether to use strongly_typed.") + force_num_profiles: Optional[int] = Field( + default=None, + description= + "Force a specific number of optimization profiles. If None, auto-determined." + ) + profiling_verbosity: str = Field( + default='layer_names_only', + description= + "Verbosity level for TensorRT profiling ('layer_names_only', 'detailed', 'none')." + ) + enable_debug_output: bool = Field( + default=False, + description="Whether to enable debug output during building.") + max_draft_len: int = Field( + default=0, + description="Maximum length of draft tokens for speculative decoding.") + speculative_decoding_mode: SpeculativeDecodingMode = Field( + default=SpeculativeDecodingMode.NONE, + description="Mode for speculative decoding (NONE, MEDUSA, EAGLE, etc.)." 
+ ) + use_refit: bool = Field( + default=False, + description="Whether to enable engine refitting capabilities.") + input_timing_cache: Optional[str] = Field( + default=None, + description= + "Path to input timing cache file. If None, no input cache used.") + output_timing_cache: str = Field( + default='model.cache', description="Path to output timing cache file.") + lora_config: LoraConfig = Field( + default_factory=LoraConfig, + description="Configuration for LoRA (Low-Rank Adaptation) fine-tuning.") + weight_sparsity: bool = Field( + default=False, + description="Whether to enable weight sparsity optimization.") + weight_streaming: bool = Field( + default=False, + description="Whether to enable weight streaming for large models.") + plugin_config: PluginConfig = Field( + default_factory=PluginConfig, + description="Configuration for TensorRT LLM plugins.") + use_strip_plan: bool = Field( + default=False, + description="Whether to use stripped plan for engine building.") + max_encoder_input_len: int = Field( + default=1024, + description="Maximum encoder input length for encoder-decoder models.") + dry_run: bool = Field( + default=False, + description= + "Whether to perform a dry run without actually building the engine.") + visualize_network: Optional[str] = Field( + default=None, + description= + "Path to save network visualization. If None, no visualization generated." + ) + monitor_memory: bool = Field( + default=False, + description="Whether to monitor memory usage during building.") + use_mrope: bool = Field( + default=False, + description= + "Whether to use Multi-RoPE (Rotary Position Embedding) optimization.") # Since we have some overlapping between kv_cache_type, paged_kv_cache, and paged_state (later two will be deprecated in the future), # we need to handle it given model architecture. 
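
For reference, the migrated BuildConfig behaves like any other Pydantic model: per-field defaults are introspectable through model_fields (the pattern the CLI code below switches to), construction accepts keyword arguments or an unpacked dict, and model_dump / model_copy replace the removed to_dict() and copy.deepcopy() helpers. A minimal usage sketch, assuming tensorrt_llm is installed; it is illustrative only and not part of the patch:

    from tensorrt_llm.builder import BuildConfig

    # Read a field default without instantiating the model; this replaces the
    # old class-attribute lookup BuildConfig.max_batch_size.
    default_max_batch_size = BuildConfig.model_fields["max_batch_size"].default

    # Construct from keyword arguments or by unpacking a dict loaded from config.json.
    cfg = BuildConfig(max_batch_size=8, max_num_tokens=256)
    same_cfg = BuildConfig(**{"max_batch_size": 8, "max_num_tokens": 256})

    # JSON-safe serialization and deep copies replace to_dict() and copy.deepcopy().
    cfg_json = cfg.model_dump(mode="json")
    cfg_copy = cfg.model_copy(deep=True)
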
@@ -574,144 +605,10 @@ class BuildConfig: override_attri('paged_state', False) @classmethod - @cache - def get_build_config_defaults(cls): - return { - field.name: field.default - for field in dataclasses.fields(cls) - if field.default is not dataclasses.MISSING - } - - @classmethod - def from_dict(cls, config, plugin_config=None): - config = copy.deepcopy( - config - ) # it just does not make sense to change the input arg `config` - - defaults = cls.get_build_config_defaults() - max_input_len = config.pop('max_input_len', - defaults.get('max_input_len')) - max_seq_len = config.pop('max_seq_len', defaults.get('max_seq_len')) - max_batch_size = config.pop('max_batch_size', - defaults.get('max_batch_size')) - max_beam_width = config.pop('max_beam_width', - defaults.get('max_beam_width')) - max_num_tokens = config.pop('max_num_tokens', - defaults.get('max_num_tokens')) - opt_num_tokens = config.pop('opt_num_tokens', - defaults.get('opt_num_tokens')) - opt_batch_size = config.pop('opt_batch_size', - defaults.get('opt_batch_size')) - max_prompt_embedding_table_size = config.pop( - 'max_prompt_embedding_table_size', - defaults.get('max_prompt_embedding_table_size')) - - if "kv_cache_type" in config and config["kv_cache_type"] is not None: - kv_cache_type = KVCacheType.from_string(config.pop('kv_cache_type')) - else: - kv_cache_type = None - gather_context_logits = config.pop( - 'gather_context_logits', defaults.get('gather_context_logits')) - gather_generation_logits = config.pop( - 'gather_generation_logits', - defaults.get('gather_generation_logits')) - strongly_typed = config.pop('strongly_typed', - defaults.get('strongly_typed')) - force_num_profiles = config.pop('force_num_profiles', - defaults.get('force_num_profiles')) - weight_sparsity = config.pop('weight_sparsity', - defaults.get('weight_sparsity')) - profiling_verbosity = config.pop('profiling_verbosity', - defaults.get('profiling_verbosity')) - enable_debug_output = config.pop('enable_debug_output', - defaults.get('enable_debug_output')) - max_draft_len = config.pop('max_draft_len', - defaults.get('max_draft_len')) - speculative_decoding_mode = config.pop( - 'speculative_decoding_mode', - defaults.get('speculative_decoding_mode')) - use_refit = config.pop('use_refit', defaults.get('use_refit')) - input_timing_cache = config.pop('input_timing_cache', - defaults.get('input_timing_cache')) - output_timing_cache = config.pop('output_timing_cache', - defaults.get('output_timing_cache')) - lora_config = LoraConfig(**config.get('lora_config', {})) - max_encoder_input_len = config.pop( - 'max_encoder_input_len', defaults.get('max_encoder_input_len')) - weight_streaming = config.pop('weight_streaming', - defaults.get('weight_streaming')) - use_strip_plan = config.pop('use_strip_plan', - defaults.get('use_strip_plan')) - - if plugin_config is None: - plugin_config = PluginConfig() - if "plugin_config" in config.keys(): - plugin_config = plugin_config.model_copy( - update=config["plugin_config"], deep=True) - - dry_run = config.pop('dry_run', defaults.get('dry_run')) - visualize_network = config.pop('visualize_network', - defaults.get('visualize_network')) - monitor_memory = config.pop('monitor_memory', - defaults.get('monitor_memory')) - use_mrope = config.pop('use_mrope', defaults.get('use_mrope')) - - return cls( - max_input_len=max_input_len, - max_seq_len=max_seq_len, - max_batch_size=max_batch_size, - max_beam_width=max_beam_width, - max_num_tokens=max_num_tokens, - opt_num_tokens=opt_num_tokens, - opt_batch_size=opt_batch_size, - 
max_prompt_embedding_table_size=max_prompt_embedding_table_size, - kv_cache_type=kv_cache_type, - gather_context_logits=gather_context_logits, - gather_generation_logits=gather_generation_logits, - strongly_typed=strongly_typed, - force_num_profiles=force_num_profiles, - profiling_verbosity=profiling_verbosity, - enable_debug_output=enable_debug_output, - max_draft_len=max_draft_len, - speculative_decoding_mode=speculative_decoding_mode, - use_refit=use_refit, - input_timing_cache=input_timing_cache, - output_timing_cache=output_timing_cache, - lora_config=lora_config, - use_strip_plan=use_strip_plan, - max_encoder_input_len=max_encoder_input_len, - weight_sparsity=weight_sparsity, - weight_streaming=weight_streaming, - plugin_config=plugin_config, - dry_run=dry_run, - visualize_network=visualize_network, - monitor_memory=monitor_memory, - use_mrope=use_mrope) - - @classmethod - def from_json_file(cls, config_file, plugin_config=None): + def from_json_file(cls, config_file): with open(config_file) as f: config = json.load(f) - return BuildConfig.from_dict(config, plugin_config=plugin_config) - - def to_dict(self): - output = copy.deepcopy(self.__dict__) - # the enum KVCacheType cannot be converted automatically - if output.get('kv_cache_type', None) is not None: - output['kv_cache_type'] = str(output['kv_cache_type'].name) - output['plugin_config'] = output['plugin_config'].model_dump() - output['lora_config'] = output['lora_config'].model_dump() - return output - - def update_from_dict(self, config: dict): - for name, value in config.items(): - if not hasattr(self, name): - raise AttributeError( - f"{self.__class__} object has no attribute {name}") - setattr(self, name, value) - - def update(self, **kwargs): - self.update_from_dict(kwargs) + return BuildConfig(**config) class EngineConfig: @@ -731,11 +628,10 @@ class EngineConfig: def from_json_str(cls, config_str): config = json.loads(config_str) return cls(PretrainedConfig.from_dict(config['pretrained_config']), - BuildConfig.from_dict(config['build_config']), - config['version']) + BuildConfig(**config['build_config']), config['version']) def to_dict(self): - build_config = self.build_config.to_dict() + build_config = self.build_config.model_dump(mode="json") build_config.pop('dry_run', None) # Not an Engine Characteristic build_config.pop('visualize_network', None) # Not an Engine Characteristic @@ -1081,7 +977,7 @@ def build(model: PretrainedModel, build_config: BuildConfig) -> Engine: ''' tic = time.time() # avoid changing the input config - build_config = copy.deepcopy(build_config) + build_config = build_config.model_copy(deep=True) build_config.plugin_config.dtype = model.config.dtype build_config.update_kv_cache_type(model.config.architecture) diff --git a/tensorrt_llm/commands/build.py b/tensorrt_llm/commands/build.py index 85176b2232..e246f1c8d4 100644 --- a/tensorrt_llm/commands/build.py +++ b/tensorrt_llm/commands/build.py @@ -26,8 +26,8 @@ import torch from tensorrt_llm._utils import (local_mpi_rank, local_mpi_size, mpi_barrier, mpi_comm, mpi_rank, mpi_world_size) -from tensorrt_llm.bindings import KVCacheType from tensorrt_llm.builder import BuildConfig, Engine, build +from tensorrt_llm.llmapi.kv_cache_type import KVCacheType from tensorrt_llm.logger import logger, severity_map from tensorrt_llm.lora_helper import LoraConfig from tensorrt_llm.lora_manager import LoraManager @@ -37,23 +37,6 @@ from tensorrt_llm.plugin import PluginConfig, add_plugin_argument from tensorrt_llm.quantization.mode import QuantAlgo -def 
enum_type(enum_class): - - def parse_enum(value): - if isinstance(value, enum_class): - return value - - if isinstance(value, str): - return enum_class.from_string(value) - - valid_values = [e.name for e in enum_class] - raise argparse.ArgumentTypeError( - f"Invalid value '{value}' of type {type(value).__name__}. Expected one of {valid_values}" - ) - - return parse_enum - - def parse_arguments(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -92,29 +75,30 @@ def parse_arguments(): parser.add_argument( '--max_batch_size', type=int, - default=BuildConfig.max_batch_size, + default=BuildConfig.model_fields["max_batch_size"].default, help="Maximum number of requests that the engine can schedule.") - parser.add_argument('--max_input_len', - type=int, - default=BuildConfig.max_input_len, - help="Maximum input length of one request.") + parser.add_argument( + '--max_input_len', + type=int, + default=BuildConfig.model_fields["max_input_len"].default, + help="Maximum input length of one request.") parser.add_argument( '--max_seq_len', '--max_decoder_seq_len', dest='max_seq_len', type=int, - default=BuildConfig.max_seq_len, + default=BuildConfig.model_fields["max_seq_len"].default, help="Maximum total length of one request, including prompt and outputs. " "If unspecified, the value is deduced from the model config.") parser.add_argument( '--max_beam_width', type=int, - default=BuildConfig.max_beam_width, + default=BuildConfig.model_fields["max_beam_width"].default, help="Maximum number of beams for beam search decoding.") parser.add_argument( '--max_num_tokens', type=int, - default=BuildConfig.max_num_tokens, + default=BuildConfig.model_fields["max_num_tokens"].default, help= "Maximum number of batched input tokens after padding is removed in each batch. " "Currently, the input padding is removed by default; " @@ -123,7 +107,7 @@ def parse_arguments(): parser.add_argument( '--opt_num_tokens', type=int, - default=BuildConfig.opt_num_tokens, + default=BuildConfig.model_fields["opt_num_tokens"].default, help= "Optimal number of batched input tokens after padding is removed in each batch " "It equals to ``max_batch_size * max_beam_width`` by default, set this " @@ -132,7 +116,7 @@ def parse_arguments(): parser.add_argument( '--max_encoder_input_len', type=int, - default=BuildConfig.max_encoder_input_len, + default=BuildConfig.model_fields["max_encoder_input_len"].default, help="Maximum encoder input length for enc-dec models. " "Set ``max_input_len`` to 1 to start generation from decoder_start_token_id of length 1." ) @@ -140,14 +124,15 @@ def parse_arguments(): '--max_prompt_embedding_table_size', '--max_multimodal_len', type=int, - default=BuildConfig.max_prompt_embedding_table_size, + default=BuildConfig.model_fields["max_prompt_embedding_table_size"]. + default, help= "Maximum prompt embedding table size for prompt tuning, or maximum multimodal input size for multimodal models. " "Setting a value > 0 enables prompt tuning or multimodal input.") parser.add_argument( '--kv_cache_type', default=argparse.SUPPRESS, - type=enum_type(KVCacheType), + type=KVCacheType, help= "Set KV cache type (continuous, paged, or disabled). For disabled case, KV cache is disabled and only context phase is allowed." ) @@ -156,42 +141,44 @@ def parse_arguments(): type=str, default=argparse.SUPPRESS, help= - "Deprecated. Enabling this option is equvilient to ``--kv_cache_type paged`` for transformer based models." + "Deprecated. 
Enabling this option is equivalent to ``--kv_cache_type paged`` for transformer based models." ) parser.add_argument( '--input_timing_cache', type=str, - default=BuildConfig.input_timing_cache, + default=BuildConfig.model_fields["input_timing_cache"].default, help= "The file path to read the timing cache. This option is ignored if the file does not exist." ) - parser.add_argument('--output_timing_cache', - type=str, - default=BuildConfig.output_timing_cache, - help="The file path to write the timing cache.") + parser.add_argument( + '--output_timing_cache', + type=str, + default=BuildConfig.model_fields["output_timing_cache"].default, + help="The file path to write the timing cache.") parser.add_argument( '--profiling_verbosity', type=str, - default=BuildConfig.profiling_verbosity, + default=BuildConfig.model_fields["profiling_verbosity"].default, choices=['layer_names_only', 'detailed', 'none'], help= "The profiling verbosity for the generated TensorRT engine. Setting to detailed allows inspecting tactic choices and kernel parameters." ) parser.add_argument( '--strip_plan', - default=BuildConfig.use_strip_plan, + default=BuildConfig.model_fields["use_strip_plan"].default, action='store_true', help= "Enable stripping weights from the final TensorRT engine under the assumption that the refit weights are identical to those provided at build time." ) - parser.add_argument('--weight_sparsity', - default=BuildConfig.weight_sparsity, - action='store_true', - help="Enable weight sparsity.") + parser.add_argument( + '--weight_sparsity', + default=BuildConfig.model_fields["weight_sparsity"].default, + action='store_true', + help="Enable weight sparsity.") parser.add_argument( '--weight_streaming', - default=BuildConfig.weight_streaming, + default=BuildConfig.model_fields["weight_streaming"].default, action='store_true', help= "Enable offloading weights to CPU and streaming loading at runtime.", @@ -213,10 +200,11 @@ def parse_arguments(): default='info', choices=severity_map.keys(), help="The logging level.") - parser.add_argument('--enable_debug_output', - default=BuildConfig.enable_debug_output, - action='store_true', - help="Enable debug output.") + parser.add_argument( + '--enable_debug_output', + default=BuildConfig.model_fields["enable_debug_output"].default, + action='store_true', + help="Enable debug output.") parser.add_argument( '--visualize_network', type=str, @@ -226,7 +214,7 @@ def parse_arguments(): ) parser.add_argument( '--dry_run', - default=BuildConfig.dry_run, + default=BuildConfig.model_fields["dry_run"].default, action='store_true', help= "Run through the build process except the actual Engine build for debugging." @@ -519,65 +507,37 @@ def main(): f"Overriding # of builder profiles <= {force_num_profiles_from_env}." 
) - build_config = BuildConfig.from_dict( - { - 'max_input_len': - args.max_input_len, - 'max_seq_len': - args.max_seq_len, - 'max_batch_size': - args.max_batch_size, - 'max_beam_width': - args.max_beam_width, - 'max_num_tokens': - args.max_num_tokens, - 'opt_num_tokens': - args.opt_num_tokens, - 'max_prompt_embedding_table_size': - args.max_prompt_embedding_table_size, - 'gather_context_logits': - args.gather_context_logits, - 'gather_generation_logits': - args.gather_generation_logits, - 'strongly_typed': - True, - 'force_num_profiles': - force_num_profiles_from_env, - 'weight_sparsity': - args.weight_sparsity, - 'profiling_verbosity': - args.profiling_verbosity, - 'enable_debug_output': - args.enable_debug_output, - 'max_draft_len': - args.max_draft_len, - 'speculative_decoding_mode': - speculative_decoding_mode, - 'input_timing_cache': - args.input_timing_cache, - 'output_timing_cache': - args.output_timing_cache, - 'dry_run': - args.dry_run, - 'visualize_network': - args.visualize_network, - 'max_encoder_input_len': - args.max_encoder_input_len, - 'weight_streaming': - args.weight_streaming, - 'monitor_memory': - args.monitor_memory, - 'use_mrope': - (True if model_config.qwen_type == "qwen2_vl" else False) - if hasattr(model_config, "qwen_type") else False - }, + build_config = BuildConfig( + max_input_len=args.max_input_len, + max_seq_len=args.max_seq_len, + max_batch_size=args.max_batch_size, + max_beam_width=args.max_beam_width, + max_num_tokens=args.max_num_tokens, + opt_num_tokens=args.opt_num_tokens, + max_prompt_embedding_table_size=args. + max_prompt_embedding_table_size, + kv_cache_type=getattr(args, "kv_cache_type", None), + gather_context_logits=args.gather_context_logits, + gather_generation_logits=args.gather_generation_logits, + strongly_typed=True, + force_num_profiles=force_num_profiles_from_env, + weight_sparsity=args.weight_sparsity, + profiling_verbosity=args.profiling_verbosity, + enable_debug_output=args.enable_debug_output, + max_draft_len=args.max_draft_len, + speculative_decoding_mode=speculative_decoding_mode, + input_timing_cache=args.input_timing_cache, + output_timing_cache=args.output_timing_cache, + dry_run=args.dry_run, + visualize_network=args.visualize_network, + max_encoder_input_len=args.max_encoder_input_len, + weight_streaming=args.weight_streaming, + monitor_memory=args.monitor_memory, + use_mrope=getattr(model_config, "qwen_type", None) == "qwen2_vl", plugin_config=plugin_config) - - if hasattr(args, 'kv_cache_type'): - build_config.update_from_dict({'kv_cache_type': args.kv_cache_type}) else: - build_config = BuildConfig.from_json_file(args.build_config, - plugin_config=plugin_config) + build_config = BuildConfig.from_json_file(args.build_config) + build_config.plugin_config = plugin_config parallel_build(model_config, ckpt_dir, build_config, args.output_dir, workers, args.log_level, model_cls, **kwargs) diff --git a/tensorrt_llm/commands/eval.py b/tensorrt_llm/commands/eval.py index 937edb5af8..8a0b4f5826 100644 --- a/tensorrt_llm/commands/eval.py +++ b/tensorrt_llm/commands/eval.py @@ -50,23 +50,23 @@ from ..logger import logger, severity_map help="The logging level.") @click.option("--max_beam_width", type=int, - default=BuildConfig.max_beam_width, + default=BuildConfig.model_fields["max_beam_width"].default, help="Maximum number of beams for beam search decoding.") @click.option("--max_batch_size", type=int, - default=BuildConfig.max_batch_size, + default=BuildConfig.model_fields["max_batch_size"].default, help="Maximum number of requests 
that the engine can schedule.") @click.option( "--max_num_tokens", type=int, - default=BuildConfig.max_num_tokens, + default=BuildConfig.model_fields["max_num_tokens"].default, help= "Maximum number of batched input tokens after padding is removed in each batch." ) @click.option( "--max_seq_len", type=int, - default=BuildConfig.max_seq_len, + default=BuildConfig.model_fields["max_seq_len"].default, help="Maximum total length of one request, including prompt and outputs. " "If unspecified, the value is deduced from the model config.") @click.option("--tp_size", type=int, default=1, help='Tensor parallelism size.') diff --git a/tensorrt_llm/commands/refit.py b/tensorrt_llm/commands/refit.py index 218789a021..2243e72ed5 100644 --- a/tensorrt_llm/commands/refit.py +++ b/tensorrt_llm/commands/refit.py @@ -2,7 +2,6 @@ Script that refits TRT-LLM engine(s) with weights in a TRT-LLM checkpoint. ''' import argparse -import copy import json import os import re @@ -57,7 +56,7 @@ def refit_engine(engine_path: str, refit_engine_dir: str, checkpoint_dir: str, # There are weights preprocess during optimize model. tik = time.time() - build_config = copy.deepcopy(engine_config.build_config) + build_config = engine_config.build_config.model_copy(deep=True) optimize_model_with_config(model, build_config) tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index 8e43937a1d..6c1c17cf9d 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -75,25 +75,29 @@ def _signal_handler_cleanup_child(signum, frame): sys.exit(128 + signum) -def get_llm_args(model: str, - tokenizer: Optional[str] = None, - backend: str = "pytorch", - max_beam_width: int = BuildConfig.max_beam_width, - max_batch_size: int = BuildConfig.max_batch_size, - max_num_tokens: int = BuildConfig.max_num_tokens, - max_seq_len: int = BuildConfig.max_seq_len, - tensor_parallel_size: int = 1, - pipeline_parallel_size: int = 1, - moe_expert_parallel_size: Optional[int] = None, - gpus_per_node: Optional[int] = None, - free_gpu_memory_fraction: float = 0.9, - num_postprocess_workers: int = 0, - trust_remote_code: bool = False, - reasoning_parser: Optional[str] = None, - fail_fast_on_attention_window_too_large: bool = False, - otlp_traces_endpoint: Optional[str] = None, - enable_chunked_prefill: bool = False, - **llm_args_extra_dict: Any): +def get_llm_args( + model: str, + tokenizer: Optional[str] = None, + backend: str = "pytorch", + max_beam_width: int = BuildConfig.model_fields["max_beam_width"]. + default, + max_batch_size: int = BuildConfig.model_fields["max_batch_size"]. + default, + max_num_tokens: int = BuildConfig.model_fields["max_num_tokens"]. 
+ default, + max_seq_len: int = BuildConfig.model_fields["max_seq_len"].default, + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + moe_expert_parallel_size: Optional[int] = None, + gpus_per_node: Optional[int] = None, + free_gpu_memory_fraction: float = 0.9, + num_postprocess_workers: int = 0, + trust_remote_code: bool = False, + reasoning_parser: Optional[str] = None, + fail_fast_on_attention_window_too_large: bool = False, + otlp_traces_endpoint: Optional[str] = None, + enable_chunked_prefill: bool = False, + **llm_args_extra_dict: Any): if gpus_per_node is None: gpus_per_node = device_count() @@ -242,23 +246,23 @@ class ChoiceWithAlias(click.Choice): help="The logging level.") @click.option("--max_beam_width", type=int, - default=BuildConfig.max_beam_width, + default=BuildConfig.model_fields["max_beam_width"].default, help="Maximum number of beams for beam search decoding.") @click.option("--max_batch_size", type=int, - default=BuildConfig.max_batch_size, + default=BuildConfig.model_fields["max_batch_size"].default, help="Maximum number of requests that the engine can schedule.") @click.option( "--max_num_tokens", type=int, - default=BuildConfig.max_num_tokens, + default=BuildConfig.model_fields["max_num_tokens"].default, help= "Maximum number of batched input tokens after padding is removed in each batch." ) @click.option( "--max_seq_len", type=int, - default=BuildConfig.max_seq_len, + default=BuildConfig.model_fields["max_seq_len"].default, help="Maximum total length of one request, including prompt and outputs. " "If unspecified, the value is deduced from the model config.") @click.option("--tp_size", type=int, default=1, help='Tensor parallelism size.') @@ -436,7 +440,7 @@ def serve( help="The logging level.") @click.option("--max_batch_size", type=int, - default=BuildConfig.max_batch_size, + default=BuildConfig.model_fields["max_batch_size"].default, help="Maximum number of requests that the engine can schedule.") @click.option( "--max_num_tokens", diff --git a/tensorrt_llm/llmapi/build_cache.py b/tensorrt_llm/llmapi/build_cache.py index 58ecd1be03..a666ab1fff 100644 --- a/tensorrt_llm/llmapi/build_cache.py +++ b/tensorrt_llm/llmapi/build_cache.py @@ -104,7 +104,7 @@ class BuildCache: Get the build step for engine building. ''' build_config_str = json.dumps(self.prune_build_config_for_cache_key( - build_config.to_dict()), + build_config.model_dump(mode="json")), sort_keys=True) kwargs_str = json.dumps(kwargs, sort_keys=True) diff --git a/tensorrt_llm/llmapi/kv_cache_type.py b/tensorrt_llm/llmapi/kv_cache_type.py new file mode 100644 index 0000000000..eee01260ec --- /dev/null +++ b/tensorrt_llm/llmapi/kv_cache_type.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from enum import Enum +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import tensorrt_llm.bindings as _bindings + + +class KVCacheType(str, Enum): + """Python enum wrapper for KVCacheType. + + This is a pure Python enum that mirrors the C++ KVCacheType enum exposed + through pybind11. + """ + CONTINUOUS = "continuous" + PAGED = "paged" + DISABLED = "disabled" + + @classmethod + def _missing_(cls, value): + """Allow case-insensitive string values to be converted to enum members.""" + if isinstance(value, str): + for member in cls: + if member.value.lower() == value.lower(): + return member + return None + + def to_cpp(self) -> '_bindings.KVCacheType': + import tensorrt_llm.bindings as _bindings + return getattr(_bindings.KVCacheType, self.name) + + @classmethod + def from_cpp(cls, cpp_enum) -> 'KVCacheType': + # C++ enum's __str__ returns "KVCacheType.PAGED", extract the name + name = str(cpp_enum).split('.')[-1] + return cls(name) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 7b907ef827..4a01844fb8 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1,5 +1,4 @@ import ast -import copy import functools import json import math @@ -1764,17 +1763,6 @@ class BaseLlmArgs(StrictBaseModel): ret = cls(**kwargs) return ret - def to_dict(self) -> dict: - """Dump `LlmArgs` instance to a dict. - - Returns: - dict: The dict that contains all fields of the `LlmArgs` instance. - """ - model_dict = self.model_dump(mode='json') - # TODO: the BuildConfig.to_dict and from_dict don't work well with pydantic - model_dict['build_config'] = copy.deepcopy(self.build_config) - return model_dict - @staticmethod def _check_consistency(kwargs_dict: Dict[str, Any]) -> Dict[str, Any]: # max_beam_width is not included since vague behavior due to lacking the support for dynamic beam width during @@ -1919,10 +1907,6 @@ class BaseLlmArgs(StrictBaseModel): if self.max_input_len: kwargs["max_input_len"] = self.max_input_len self.build_config = BuildConfig(**kwargs) - else: - assert isinstance( - build_config, - BuildConfig), f"build_config is not initialized: {build_config}" return self @model_validator(mode="after") @@ -2001,7 +1985,7 @@ class BaseLlmArgs(StrictBaseModel): # TODO: remove the checker when manage weights support all data types if is_trt_llm_args and self.fast_build and (self.quant_config.quant_algo is QuantAlgo.FP8): - self._update_plugin_config("manage_weights", True) + self.build_config.plugin_config.manage_weights = True if self.parallel_config.world_size == 1 and self.build_config: self.build_config.plugin_config.nccl_plugin = None @@ -2166,9 +2150,6 @@ class BaseLlmArgs(StrictBaseModel): "while LoRA prefetch is not supported") return self - def _update_plugin_config(self, key: str, value: Any): - setattr(self.build_config.plugin_config, key, value) - def _load_config_from_engine(self, engine_dir: Path): engine_config = EngineConfig.from_json_file(engine_dir / "config.json") self._pretrained_config = engine_config.pretrained_config @@ -2271,10 +2252,8 @@ class TrtLlmArgs(BaseLlmArgs): fast_build: bool = Field(default=False, description="Enable fast build.") # BuildConfig is introduced to give users a familiar interface to configure the model building. 
- build_config: Optional[object] = Field( - default=None, - description="Build config.", - json_schema_extra={"type": f"Optional[{get_type_repr(BuildConfig)}]"}) + build_config: Optional[BuildConfig] = Field(default=None, + description="Build config.") # Prompt adapter arguments enable_prompt_adapter: bool = Field(default=False, @@ -2405,11 +2384,10 @@ class TorchCompileConfig(StrictBaseModel): class TorchLlmArgs(BaseLlmArgs): # Just a dummy BuildConfig to allow code reuse with the TrtLlmArgs - build_config: Optional[object] = Field( + build_config: Optional[BuildConfig] = Field( default=None, description="Build config.", exclude_from_json=True, - json_schema_extra={"type": f"Optional[{get_type_repr(BuildConfig)}]"}, status="deprecated", ) @@ -2911,10 +2889,7 @@ def update_llm_args_with_extra_dict( for field_name, field_type in field_mapping.items(): if field_name in llm_args_dict: # Some fields need to be converted manually. - if field_name in [ - "speculative_config", "build_config", - "sparse_attention_config" - ]: + if field_name in ["speculative_config", "sparse_attention_config"]: llm_args_dict[field_name] = field_type.from_dict( llm_args_dict[field_name]) else: @@ -2928,6 +2903,10 @@ def update_llm_args_with_extra_dict( # For trtllm-bench or trtllm-serve, build_config may be passed for the PyTorch # backend, overwriting the knobs there since build_config always has the highest priority if "build_config" in llm_args: + # Ensure build_config is a BuildConfig object, not a dict + if isinstance(llm_args["build_config"], dict): + llm_args["build_config"] = BuildConfig(**llm_args["build_config"]) + for key in [ "max_batch_size", "max_num_tokens", diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py index efda96c2a7..ecd0e5bfc1 100644 --- a/tensorrt_llm/llmapi/llm_utils.py +++ b/tensorrt_llm/llmapi/llm_utils.py @@ -1,4 +1,3 @@ -import copy import json import os import shutil @@ -530,8 +529,8 @@ class ModelLoader: logger_debug(f"rank{mpi_rank()} begin to build engine...\n", "green") - # avoid the original build_config is modified, avoid the side effect - copied_build_config = copy.deepcopy(self.build_config) + # avoid side effects by copying the original build_config + copied_build_config = self.build_config.model_copy(deep=True) copied_build_config.update_kv_cache_type(self._model_info.architecture) assert self.model is not None, "model is loaded yet." 
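
The new pure-Python KVCacheType is a str-backed enum: _missing_ coerces case-insensitive strings, so it can be passed directly as an argparse type= converter (replacing the enum_type() helper removed earlier in this patch), and to_cpp() / from_cpp() bridge to the pybind11 enum. A small sketch, assuming only that the package is importable; the C++ round trip additionally needs the compiled bindings, so it is shown commented out:

    from tensorrt_llm.llmapi.kv_cache_type import KVCacheType

    # Case-insensitive string coercion: values from config.json ("paged") and
    # legacy upper-case names ("PAGED") resolve to the same member.
    assert KVCacheType("paged") is KVCacheType.PAGED
    assert KVCacheType("PAGED") is KVCacheType.PAGED

    # str-backed members serialize naturally, e.g. in BuildConfig.model_dump(mode="json").
    assert KVCacheType.PAGED.value == "paged"

    # Bridging to the C++ enum (requires the compiled tensorrt_llm.bindings):
    # cpp_value = KVCacheType.PAGED.to_cpp()
    # assert KVCacheType.from_cpp(cpp_value) is KVCacheType.PAGED
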
diff --git a/tensorrt_llm/models/eagle/model.py b/tensorrt_llm/models/eagle/model.py index e6edc7c676..779f397d70 100644 --- a/tensorrt_llm/models/eagle/model.py +++ b/tensorrt_llm/models/eagle/model.py @@ -27,10 +27,10 @@ from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader from ..._common import default_net, default_trtnet from ..._utils import pad_vocab_size -from ...bindings import KVCacheType from ...functional import (Tensor, _create_tensor, cast, concat, gather_last_token_logits, index_select, shape) from ...layers import AttentionParams, ColumnLinear, SpecDecodingParams +from ...llmapi.kv_cache_type import KVCacheType from ...module import Module, ModuleList from ...plugin import TRT_LLM_PLUGIN_NAMESPACE from ..modeling_utils import QuantConfig diff --git a/tensorrt_llm/models/generation_mixin.py b/tensorrt_llm/models/generation_mixin.py index f97b8d436b..5f590f1ee4 100644 --- a/tensorrt_llm/models/generation_mixin.py +++ b/tensorrt_llm/models/generation_mixin.py @@ -18,9 +18,9 @@ from typing import List, Optional import tensorrt as trt -from ..bindings import KVCacheType from ..functional import Tensor from ..layers import MropeParams, SpecDecodingParams +from ..llmapi.kv_cache_type import KVCacheType from ..mapping import Mapping from ..plugin import current_all_reduce_helper diff --git a/tensorrt_llm/models/mllama/model.py b/tensorrt_llm/models/mllama/model.py index 95a261350b..a610a48095 100644 --- a/tensorrt_llm/models/mllama/model.py +++ b/tensorrt_llm/models/mllama/model.py @@ -21,7 +21,6 @@ import torch from tensorrt_llm._common import default_net from tensorrt_llm._utils import numpy_to_torch, str_dtype_to_torch -from tensorrt_llm.bindings import KVCacheType from tensorrt_llm.functional import (Conditional, LayerNormPositionType, LayerNormType, MLPType, PositionEmbeddingType, Tensor, assertion, @@ -32,6 +31,7 @@ from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams, ColumnLinear, Embedding, FusedGatedMLP, GatedMLP, GroupNorm, KeyValueCacheParams, LayerNorm, LoraParams, RmsNorm) +from tensorrt_llm.llmapi.kv_cache_type import KVCacheType from tensorrt_llm.lora_helper import (LoraConfig, get_default_trtllm_modules_to_hf_modules, use_lora) diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 15e39b8811..03c1ee60ae 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -19,7 +19,6 @@ from .._common import default_net from .._utils import (QuantModeWrapper, get_init_params, numpy_to_torch, release_gc, str_dtype_to_torch, str_dtype_to_trt, trt_dtype_to_torch) -from ..bindings import KVCacheType from ..bindings.executor import RuntimeDefaults from ..functional import (PositionEmbeddingType, Tensor, allgather, constant, cp_split_plugin, gather_last_token_logits, @@ -31,6 +30,7 @@ from ..layers.attention import Attention, BertAttention from ..layers.linear import ColumnLinear, Linear, RowLinear from ..layers.lora import Dora, Lora from ..layers.moe import MOE, MoeOOTB +from ..llmapi.kv_cache_type import KVCacheType from ..logger import logger from ..mapping import Mapping from ..module import Module, ModuleList diff --git a/tensorrt_llm/models/nemotron_nas/model.py b/tensorrt_llm/models/nemotron_nas/model.py index a4e567f739..f6562d3a76 100644 --- a/tensorrt_llm/models/nemotron_nas/model.py +++ b/tensorrt_llm/models/nemotron_nas/model.py @@ -15,7 +15,6 @@ from dataclasses import dataclass from typing import List, Optional, Tuple, Type, Union -from 
tensorrt_llm.bindings import KVCacheType from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams, AttentionMaskType, PositionEmbeddingType, Tensor, gather_last_token_logits, recv, @@ -28,6 +27,7 @@ from tensorrt_llm.layers.linear import ColumnLinear from tensorrt_llm.layers.lora import LoraParams from tensorrt_llm.layers.mlp import GatedMLP from tensorrt_llm.layers.normalization import RmsNorm +from tensorrt_llm.llmapi.kv_cache_type import KVCacheType from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.convert_utils import has_safetensors from tensorrt_llm.models.modeling_utils import DecoderModelForCausalLM diff --git a/tensorrt_llm/models/redrafter/model.py b/tensorrt_llm/models/redrafter/model.py index 4f9f139f05..84c78cc798 100644 --- a/tensorrt_llm/models/redrafter/model.py +++ b/tensorrt_llm/models/redrafter/model.py @@ -18,8 +18,8 @@ from collections import OrderedDict import tensorrt as trt from tensorrt_llm._common import default_net -from tensorrt_llm.bindings import KVCacheType from tensorrt_llm.functional import Tensor, cast, categorical_sample +from tensorrt_llm.llmapi.kv_cache_type import KVCacheType from tensorrt_llm.models import LLaMAForCausalLM, QWenForCausalLM from tensorrt_llm.models.generation_mixin import GenerationMixin diff --git a/tensorrt_llm/models/stdit/model.py b/tensorrt_llm/models/stdit/model.py index 7e2cc5bdce..938f801868 100644 --- a/tensorrt_llm/models/stdit/model.py +++ b/tensorrt_llm/models/stdit/model.py @@ -31,7 +31,6 @@ from tqdm import tqdm import tensorrt_llm from tensorrt_llm._common import default_net from tensorrt_llm._utils import str_dtype_to_trt, trt_dtype_to_str -from tensorrt_llm.bindings import KVCacheType from tensorrt_llm.functional import (ACT2FN, AttentionMaskType, LayerNormType, PositionEmbeddingType, Tensor, constant_to_tensor_) @@ -41,6 +40,7 @@ from tensorrt_llm.layers.attention import (Attention, AttentionParams, BertAttention, KeyValueCacheParams, bert_attention, layernorm_map) from tensorrt_llm.layers.normalization import RmsNorm +from tensorrt_llm.llmapi.kv_cache_type import KVCacheType from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.generation_mixin import GenerationMixin from tensorrt_llm.models.model_weights_loader import (ModelWeightsFormat, diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py index cb63365677..6a49587d30 100755 --- a/tensorrt_llm/runtime/generation.py +++ b/tensorrt_llm/runtime/generation.py @@ -43,8 +43,9 @@ from tensorrt_llm.runtime.redrafter_utils import * from .._utils import (binding_layer_type_to_str, binding_to_str_dtype, pad_vocab_size, str_dtype_to_torch, torch_to_numpy, trt_dtype_to_torch) -from ..bindings import KVCacheType, ipc_nvls_allocate, ipc_nvls_free +from ..bindings import ipc_nvls_allocate, ipc_nvls_free from ..layers import LanguageAdapterConfig +from ..llmapi.kv_cache_type import KVCacheType from ..logger import logger from ..lora_manager import LoraManager from ..mapping import Mapping diff --git a/tensorrt_llm/runtime/model_runner.py b/tensorrt_llm/runtime/model_runner.py index 94965e66d2..308af9b012 100644 --- a/tensorrt_llm/runtime/model_runner.py +++ b/tensorrt_llm/runtime/model_runner.py @@ -25,9 +25,10 @@ import torch from .. 
import profiler from .._utils import mpi_comm, mpi_world_size, numpy_to_torch -from ..bindings import KVCacheType, MpiComm +from ..bindings import MpiComm from ..bindings.executor import Executor from ..builder import Engine, EngineConfig, get_engine_version +from ..llmapi.kv_cache_type import KVCacheType from ..logger import logger from ..mapping import Mapping from ..quantization import QuantMode @@ -86,7 +87,9 @@ def _builder_to_model_config(config: dict) -> Tuple[ModelConfig, dict]: dtype = builder_config['precision'] tp_size = builder_config['tensor_parallel'] pp_size = builder_config.get('pipeline_parallel', 1) - kv_cache_type = KVCacheType.from_string(builder_config.get('kv_cache_type')) + kv_cache_type = builder_config.get('kv_cache_type') + if kv_cache_type is not None: + kv_cache_type = KVCacheType(kv_cache_type) world_size = tp_size * pp_size assert world_size == mpi_world_size(), \ f'Engine world size ({tp_size} * {pp_size}) != Runtime world size ({mpi_world_size()})' diff --git a/tensorrt_llm/runtime/model_runner_cpp.py b/tensorrt_llm/runtime/model_runner_cpp.py index e6a5d52a82..874c3de49b 100644 --- a/tensorrt_llm/runtime/model_runner_cpp.py +++ b/tensorrt_llm/runtime/model_runner_cpp.py @@ -23,13 +23,13 @@ import torch from .. import profiler from .._utils import mpi_broadcast -from ..bindings import (DataType, GptJsonConfig, KVCacheType, ModelConfig, - WorldConfig) +from ..bindings import DataType, GptJsonConfig, ModelConfig, WorldConfig from ..bindings import executor as trtllm from ..bindings.executor import (DecodingMode, ExternalDraftTokensConfig, OrchestratorConfig, ParallelConfig) from ..builder import EngineConfig from ..layers import MropeParams +from ..llmapi.kv_cache_type import KVCacheType from ..logger import logger from ..mapping import Mapping from .generation import LogitsProcessor, LoraManager @@ -248,7 +248,8 @@ class ModelRunnerCpp(ModelRunnerMixin): json_config = GptJsonConfig.parse_file(config_path) model_config = json_config.model_config - use_kv_cache = model_config.kv_cache_type != KVCacheType.DISABLED + use_kv_cache = KVCacheType.from_cpp( + model_config.kv_cache_type) != KVCacheType.DISABLED if not model_config.use_cross_attention: assert cross_kv_cache_fraction is None, "cross_kv_cache_fraction should only be used with enc-dec models." 
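
Runtime code now normalizes kv_cache_type from two sources through the wrapper: an optional string read from config.json (model_runner.py) and the pybind11 enum carried by the bindings' ModelConfig (model_runner_cpp.py). A hedged sketch of both paths; the helper names are illustrative, and the second path assumes the compiled bindings are available:

    from tensorrt_llm.llmapi.kv_cache_type import KVCacheType

    def kv_cache_enabled_from_json(builder_config: dict) -> bool:
        # model_runner.py path: the key may be missing or hold a lower-case string.
        kv_cache_type = builder_config.get("kv_cache_type")
        if kv_cache_type is not None:
            kv_cache_type = KVCacheType(kv_cache_type)
        return kv_cache_type != KVCacheType.DISABLED

    def kv_cache_enabled_from_bindings(model_config) -> bool:
        # model_runner_cpp.py path: the value is the C++ enum exposed by pybind11.
        return KVCacheType.from_cpp(model_config.kv_cache_type) != KVCacheType.DISABLED

    assert kv_cache_enabled_from_json({"kv_cache_type": "paged"})
    assert not kv_cache_enabled_from_json({"kv_cache_type": "disabled"})
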
diff --git a/tests/integration/defs/accuracy/accuracy_core.py b/tests/integration/defs/accuracy/accuracy_core.py
index 9907692a9e..7694cdb558 100644
--- a/tests/integration/defs/accuracy/accuracy_core.py
+++ b/tests/integration/defs/accuracy/accuracy_core.py
@@ -671,7 +671,8 @@ class CliFlowAccuracyTestHarness:
                 f"--max_tokens_in_paged_kv_cache={max_tokens_in_paged_kv_cache}"
             ])
 
-        if task.MAX_INPUT_LEN + task.MAX_OUTPUT_LEN > BuildConfig.max_num_tokens:
+        if task.MAX_INPUT_LEN + task.MAX_OUTPUT_LEN > BuildConfig.model_fields[
+                "max_num_tokens"].default:
             summarize_cmd.append("--enable_chunked_context")
 
         if self.extra_summarize_args:
diff --git a/tests/integration/defs/llmapi/test_llm_e2e.py b/tests/integration/defs/llmapi/test_llm_e2e.py
index a87b78828b..e2f0160d80 100644
--- a/tests/integration/defs/llmapi/test_llm_e2e.py
+++ b/tests/integration/defs/llmapi/test_llm_e2e.py
@@ -142,14 +142,14 @@ def test_llmapi_build_command_parameters_align(llm_root, llm_venv, engine_dir,
     with open(os.path.join(engine_dir, "config.json"), "r") as f:
         engine_config = json.load(f)
 
-    build_cmd_cfg = BuildConfig.from_dict(
-        engine_config["build_config"]).to_dict()
+    build_cmd_cfg = BuildConfig(
+        **engine_config["build_config"]).model_dump()
 
     with open(os.path.join(tmpdir.name, "config.json"), "r") as f:
         llm_api_engine_cfg = json.load(f)
 
-    build_llmapi_cfg = BuildConfig.from_dict(
-        llm_api_engine_cfg["build_config"]).to_dict()
+    build_llmapi_cfg = BuildConfig(
+        **llm_api_engine_cfg["build_config"]).model_dump()
 
     assert build_cmd_cfg == build_llmapi_cfg
diff --git a/tests/integration/defs/perf/allowed_configs.py b/tests/integration/defs/perf/allowed_configs.py
index 5a4885796c..7bf838ea89 100644
--- a/tests/integration/defs/perf/allowed_configs.py
+++ b/tests/integration/defs/perf/allowed_configs.py
@@ -1636,7 +1636,7 @@ def get_allowed_models(benchmark_type=None):
                    if i.benchmark_type == benchmark_type)
 
 
-def get_build_config(model_name, return_dict=True) -> Union[BuildConfig]:
+def get_build_config(model_name, return_dict=True) -> Union[Dict, BuildConfig]:
     if model_name in _allowed_configs:
         cfg = _allowed_configs[model_name].build_config
         return asdict(cfg) if return_dict else cfg
diff --git a/tests/integration/defs/perf/build.py b/tests/integration/defs/perf/build.py
index e4d4ca2101..a12169ba1f 100644
--- a/tests/integration/defs/perf/build.py
+++ b/tests/integration/defs/perf/build.py
@@ -255,7 +255,7 @@ def get_quant_config(quantization: str):
 
 def build_gpt(args):
     build_config = get_build_config(args.model)
-    build_config = BuildConfig.from_dict(build_config)
+    build_config = BuildConfig(**build_config)
     model_config = get_model_config(args.model)
     if args.force_num_layer_1:
         model_config['num_layers'] = 1
@@ -1448,7 +1448,7 @@ def enc_dec_build_helper(component, build_config, model_config, args):
 
 def build_enc_dec(args):
     build_config = get_build_config(args.model)
-    build_config = BuildConfig.from_dict(build_config)
+    build_config = BuildConfig(**build_config)
     model_config = get_model_config(args.model)
     if args.force_num_layer_1:
         model_config['num_layers'] = 1
diff --git a/tests/unittest/bindings/test_bindings_ut.py b/tests/unittest/bindings/test_bindings_ut.py
index f049a4437c..a48d4594e2 100644
--- a/tests/unittest/bindings/test_bindings_ut.py
+++ b/tests/unittest/bindings/test_bindings_ut.py
@@ -9,6 +9,7 @@ import torch
 from utils.runtime_defaults import assert_runtime_defaults_are_parsed_correctly
 
 import tensorrt_llm.bindings as _tb
+from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
 from tensorrt_llm.mapping import Mapping
@@ -85,6 +86,7 @@ def test_model_config():
     assert model_config.use_packed_input
     assert model_config.kv_cache_type is not None
 
+    # Test with C++ enums directly
     for enum_val in [
             _tb.KVCacheType.CONTINUOUS, _tb.KVCacheType.PAGED,
             _tb.KVCacheType.DISABLED
@@ -92,6 +94,17 @@
     ]:
         model_config.kv_cache_type = enum_val
         assert model_config.kv_cache_type == enum_val
 
+    # Test with Python enums converted to C++
+    for py_enum in [
+            KVCacheType.CONTINUOUS, KVCacheType.PAGED, KVCacheType.DISABLED
+    ]:
+        model_config.kv_cache_type = py_enum.to_cpp()
+        # Verify it was set correctly by comparing with C++ enum
+        assert model_config.kv_cache_type == getattr(_tb.KVCacheType,
+                                                     py_enum.name)
+        # Also verify round-trip conversion works
+        assert KVCacheType.from_cpp(model_config.kv_cache_type) == py_enum
+
     assert model_config.tokens_per_block == 64
     tokens_per_block = 1024
     model_config.tokens_per_block = tokens_per_block
diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py
index f95bed8a9b..2e9c59a948 100644
--- a/tests/unittest/llmapi/test_llm.py
+++ b/tests/unittest/llmapi/test_llm.py
@@ -202,7 +202,7 @@ def test_llm_build_config():
         # read the build_config and check if the parameters are correctly saved
         engine_config = json.load(f)
 
-        build_config1 = BuildConfig.from_dict(engine_config["build_config"])
+        build_config1 = BuildConfig(**engine_config["build_config"])
         # Know issue: this will be converted to None after save engine for single-gpu
         build_config1.plugin_config.nccl_plugin = 'float16'
diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py
index cd646c3378..e8b511024e 100644
--- a/tests/unittest/llmapi/test_llm_args.py
+++ b/tests/unittest/llmapi/test_llm_args.py
@@ -64,7 +64,7 @@ speculative_config:
         dict_content = self._yaml_to_dict(yaml_content)
 
         llm_args = TrtLlmArgs(model=llama_model_path)
-        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
                                                         dict_content)
         llm_args = TrtLlmArgs(**llm_args_dict)
         assert llm_args.speculative_config.max_window_size == 4
@@ -80,7 +80,7 @@ pytorch_backend_config: # this is deprecated
         dict_content = self._yaml_to_dict(yaml_content)
 
         llm_args = TrtLlmArgs(model=llama_model_path)
-        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
                                                         dict_content)
         with pytest.raises(ValueError):
             llm_args = TrtLlmArgs(**llm_args_dict)
@@ -96,7 +96,7 @@ build_config:
         dict_content = self._yaml_to_dict(yaml_content)
 
         llm_args = TrtLlmArgs(model=llama_model_path)
-        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
                                                         dict_content)
         llm_args = TrtLlmArgs(**llm_args_dict)
         assert llm_args.build_config.max_beam_width == 4
@@ -113,7 +113,7 @@ kv_cache_config:
         dict_content = self._yaml_to_dict(yaml_content)
 
         llm_args = TrtLlmArgs(model=llama_model_path)
-        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
                                                         dict_content)
         llm_args = TrtLlmArgs(**llm_args_dict)
         assert llm_args.kv_cache_config.enable_block_reuse == True
@@ -131,7 +131,7 @@ max_seq_len: 128
         dict_content = self._yaml_to_dict(yaml_content)
 
         llm_args = TrtLlmArgs(model=llama_model_path)
-        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
                                                         dict_content)
         llm_args = TrtLlmArgs(**llm_args_dict)
         assert llm_args.max_batch_size == 16
@@ -331,7 +331,7 @@ def test_update_llm_args_with_extra_dict_with_nested_dict():
         lora_config=LoraConfig(lora_ckpt_source='hf'),
         plugin_config=plugin_config)
     extra_llm_args_dict = {
-        "build_config": build_config.to_dict(),
+        "build_config": build_config.model_dump(mode="json"),
     }
 
     llm_api_args_dict = update_llm_args_with_extra_dict(llm_api_args_dict,
@@ -352,8 +352,9 @@
                 raise ValueError(f"Mismatch at {path}: {dict1} != {dict2}")
         return True
 
-    build_config_dict1 = build_config.to_dict()
-    build_config_dict2 = initialized_llm_args.build_config.to_dict()
+    build_config_dict1 = build_config.model_dump(mode="json")
+    build_config_dict2 = initialized_llm_args.build_config.model_dump(
+        mode="json")
 
     check_nested_dict_equality(build_config_dict1, build_config_dict2)
 
@@ -498,11 +499,11 @@ class TestTrtLlmArgs:
             max_num_tokens=256,
         )
         args = TrtLlmArgs(model=llama_model_path, build_config=build_config)
-        args_dict = args.to_dict()
+        args_dict = args.model_dump()
 
         new_args = TrtLlmArgs.from_kwargs(**args_dict)
 
-        assert new_args.to_dict() == args_dict
+        assert new_args.model_dump() == args_dict
 
     def test_build_config_from_engine(self):
         build_config = BuildConfig(max_batch_size=8, max_num_tokens=256)
@@ -522,6 +523,36 @@
         assert args.max_num_tokens == 16
         assert args.max_batch_size == 4
 
+    def test_model_dump_does_not_mutate_original(self):
+        """Test that model_dump() and update_llm_args_with_extra_dict don't mutate the original."""
+        # Create args with specific build_config values
+        build_config = BuildConfig(
+            max_batch_size=8,
+            max_num_tokens=256,
+        )
+        args = TrtLlmArgs(model=llama_model_path, build_config=build_config)
+
+        # Store original values
+        original_max_batch_size = args.build_config.max_batch_size
+        original_max_num_tokens = args.build_config.max_num_tokens
+
+        # Convert to dict and pass through update_llm_args_with_extra_dict with overrides
+        args_dict = args.model_dump()
+        extra_dict = {
+            "max_batch_size": 128,
+            "max_num_tokens": 1024,
+        }
+        updated_dict = update_llm_args_with_extra_dict(args_dict, extra_dict)
+
+        # Verify original args was NOT mutated
+        assert args.build_config.max_batch_size == original_max_batch_size
+        assert args.build_config.max_num_tokens == original_max_num_tokens
+
+        # Verify updated dict has new values
+        new_args = TrtLlmArgs(**updated_dict)
+        assert new_args.build_config.max_batch_size == 128
+        assert new_args.build_config.max_num_tokens == 1024
+
 
 class TestStrictBaseModelArbitraryArgs:
     """Test that StrictBaseModel prevents arbitrary arguments from being accepted."""
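Taken together, the test updates above exercise the Pydantic surface that BuildConfig gains in this patch: keyword-argument construction replaces from_dict(), and model_dump() replaces to_dict(). A small usage sketch of that round trip, assuming BuildConfig is imported from tensorrt_llm.builder as in the tests; field names are only the ones the tests touch.

    from tensorrt_llm.builder import BuildConfig

    build_config = BuildConfig(max_batch_size=8, max_num_tokens=256)

    # Dump to a JSON-compatible dict (mode="json" renders enums as strings)
    # and rebuild an equivalent config from it, mirroring the serialization
    # checks in test_llm_e2e.py and test_llm_args.py above.
    config_dict = build_config.model_dump(mode="json")
    rebuilt = BuildConfig(**config_dict)

    assert rebuilt.max_batch_size == 8
    assert rebuilt.max_num_tokens == 256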