[TRTLLM-8684][chore] Migrate BuildConfig to Pydantic, add a Python wrapper for KVCacheType enum (#8330)

Signed-off-by: Anish Shanbhag <ashanbhag@nvidia.com>
Anish Shanbhag 2025-10-28 09:17:26 -07:00 committed by GitHub
parent cdc9e5e645
commit a09b38a862
32 changed files with 363 additions and 429 deletions
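At a glance, this change replaces the dataclass-based BuildConfig with a Pydantic BaseModel and introduces a pure-Python, string-valued KVCacheType enum under tensorrt_llm.llmapi.kv_cache_type. A minimal sketch of the resulting API, assuming this branch is installed:

from tensorrt_llm.builder import BuildConfig
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType

# Pydantic construction/serialization replaces the old from_dict()/to_dict() pair.
cfg = BuildConfig(max_batch_size=16, kv_cache_type=KVCacheType.PAGED)
cfg_dict = cfg.model_dump(mode="json")   # kv_cache_type serializes to "paged"
assert BuildConfig(**cfg_dict).kv_cache_type == KVCacheType.PAGED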

View File

@ -23,7 +23,7 @@ from transformers import AutoModelForCausalLM, LlamaTokenizer
import tensorrt_llm
import tensorrt_llm.profiler as profiler
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.logger import logger
from tensorrt_llm.quantization import QuantMode
@ -97,7 +97,7 @@ def TRTLLaMA(args, config):
quantization_config = pretrained_config['quantization']
build_config = config['build_config']
kv_cache_type = KVCacheType.from_string(build_config['kv_cache_type'])
kv_cache_type = KVCacheType(build_config['kv_cache_type'])
plugin_config = build_config['plugin_config']
dtype = pretrained_config['dtype']

View File

@ -27,7 +27,7 @@ from utils import add_common_args
import tensorrt_llm
import tensorrt_llm.profiler as profiler
from tensorrt_llm import logger
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime import (PYTHON_BINDINGS, ModelConfig, ModelRunner,
SamplingConfig, Session, TensorInfo)
@ -122,8 +122,7 @@ class QWenInfer(object):
num_kv_heads = config["pretrained_config"].get("num_key_value_heads",
num_heads)
if "kv_cache_type" in config["build_config"]:
kv_cache_type = KVCacheType.from_string(
config["build_config"]["kv_cache_type"])
kv_cache_type = KVCacheType(config["build_config"]["kv_cache_type"])
else:
kv_cache_type = KVCacheType.CONTINUOUS

View File

@ -25,7 +25,7 @@ from vit_onnx_trt import Preprocss
import tensorrt_llm
import tensorrt_llm.profiler as profiler
from tensorrt_llm import logger
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime import (ModelConfig, SamplingConfig, Session,
TensorInfo)
@ -118,8 +118,7 @@ class QWenInfer(object):
num_kv_heads = config["pretrained_config"].get("num_key_value_heads",
num_heads)
if "kv_cache_type" in config["build_config"]:
kv_cache_type = KVCacheType.from_string(
config["build_config"]["kv_cache_type"])
kv_cache_type = KVCacheType(config["build_config"]["kv_cache_type"])
else:
kv_cache_type = KVCacheType.CONTINUOUS

View File

@ -33,7 +33,8 @@ import tensorrt_llm
import tensorrt_llm.logger as logger
from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
trt_dtype_to_torch)
from tensorrt_llm.bindings import GptJsonConfig, KVCacheType
from tensorrt_llm.bindings import GptJsonConfig
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelConfig, SamplingConfig
from tensorrt_llm.runtime.session import Session, TensorInfo

View File

@ -9,7 +9,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
from tensorrt_llm.models.modeling_utils import QuantConfig
from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, _ParallelConfig
from ...llmapi.utils import get_type_repr
from .models import ModelFactory, ModelFactoryRegistry
from .utils._config import DynamicYamlMixInForSettings
from .utils.logger import ad_logger
@ -318,12 +317,11 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings):
model_config = _get_config_dict()
build_config: Optional[object] = Field(
default_factory=lambda: BuildConfig(),
build_config: Optional[BuildConfig] = Field(
default_factory=BuildConfig,
description="!!! DO NOT USE !!! Internal only; needed for BaseLlmArgs compatibility.",
exclude_from_json=True,
frozen=True,
json_schema_extra={"type": f"Optional[{get_type_repr(BuildConfig)}]"},
repr=False,
)
backend: Literal["_autodeploy"] = Field(

View File

@ -22,8 +22,8 @@ TUNED_QUANTS = {
QuantAlgo.NVFP4, QuantAlgo.FP8, QuantAlgo.FP8_BLOCK_SCALES,
QuantAlgo.NO_QUANT, None
}
DEFAULT_MAX_BATCH_SIZE = BuildConfig.max_batch_size
DEFAULT_MAX_NUM_TOKENS = BuildConfig.max_num_tokens
DEFAULT_MAX_BATCH_SIZE = BuildConfig.model_fields["max_batch_size"].default
DEFAULT_MAX_NUM_TOKENS = BuildConfig.model_fields["max_num_tokens"].default
def get_benchmark_engine_settings(
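Since fields on a Pydantic BaseModel are collected into model_fields rather than left as plain class attributes, every call site that previously read a default via BuildConfig.<field> now goes through model_fields, as in the bench settings above and the CLI diffs further down. A small sketch of the pattern, with the default values taken from the BuildConfig definition in the builder.py diff below:

from tensorrt_llm.builder import BuildConfig

# The old dataclass exposed defaults as plain class attributes; on the Pydantic
# model they live on the FieldInfo objects in model_fields.
DEFAULT_MAX_BATCH_SIZE = BuildConfig.model_fields["max_batch_size"].default  # 2048
DEFAULT_MAX_NUM_TOKENS = BuildConfig.model_fields["max_num_tokens"].default  # 8192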

View File

@ -12,27 +12,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import dataclasses
import json
import math
import os
import shutil
import time
from dataclasses import dataclass, field
from functools import cache
from pathlib import Path
from typing import Dict, Optional, Union
import numpy as np
import tensorrt as trt
from pydantic import BaseModel, Field
from ._common import _is_building, check_max_num_tokens, serialize_engine
from ._utils import (get_sm_version, np_bfloat16, np_float8, str_dtype_to_trt,
to_json_file, trt_gte)
from .bindings import KVCacheType
from .functional import PositionEmbeddingType
from .graph_rewriting import optimize
from .llmapi.kv_cache_type import KVCacheType
from .logger import logger
from .lora_helper import LoraConfig
from .models import PretrainedConfig, PretrainedModel
@ -46,10 +43,7 @@ from .version import __version__
class ConfigEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, KVCacheType):
# For KVCacheType, convert it to string by split of 'KVCacheType.PAGED'.
return obj.__str__().split('.')[-1]
elif hasattr(obj, 'model_dump'):
if hasattr(obj, 'model_dump'):
# Handle Pydantic models (including DecodingBaseConfig and subclasses)
return obj.model_dump(mode='json')
else:
@ -456,75 +450,112 @@ class Builder():
logger.info(f'Config saved to {config_path}.')
@dataclass
class BuildConfig:
class BuildConfig(BaseModel):
"""Configuration class for TensorRT LLM engine building parameters.
This class contains all the configuration parameters needed to build a TensorRT LLM engine,
including sequence length limits, batch sizes, optimization settings, and various features.
Args:
max_input_len (int): Maximum length of input sequences. Defaults to 1024.
max_seq_len (int, optional): The maximum possible sequence length for a single request, including both input and generated output tokens. Defaults to None.
opt_batch_size (int): Optimal batch size for engine optimization. Defaults to 8.
max_batch_size (int): Maximum batch size the engine can handle. Defaults to 2048.
max_beam_width (int): Maximum beam width for beam search decoding. Defaults to 1.
max_num_tokens (int): Maximum number of batched input tokens after padding is removed in each batch. Defaults to 8192.
opt_num_tokens (int, optional): Optimal number of batched input tokens for engine optimization. Defaults to None.
max_prompt_embedding_table_size (int): Maximum size of prompt embedding table for prompt tuning. Defaults to 0.
kv_cache_type (KVCacheType, optional): Type of KV cache to use (CONTINUOUS or PAGED). If None, defaults to PAGED. Defaults to None.
gather_context_logits (int): Whether to gather logits during context phase. Defaults to False.
gather_generation_logits (int): Whether to gather logits during generation phase. Defaults to False.
strongly_typed (bool): Whether to use strongly_typed. Defaults to True.
force_num_profiles (int, optional): Force a specific number of optimization profiles. If None, auto-determined. Defaults to None.
profiling_verbosity (str): Verbosity level for TensorRT profiling ('layer_names_only', 'detailed', 'none'). Defaults to 'layer_names_only'.
enable_debug_output (bool): Whether to enable debug output during building. Defaults to False.
max_draft_len (int): Maximum length of draft tokens for speculative decoding. Defaults to 0.
speculative_decoding_mode (SpeculativeDecodingMode): Mode for speculative decoding (NONE, MEDUSA, EAGLE, etc.). Defaults to SpeculativeDecodingMode.NONE.
use_refit (bool): Whether to enable engine refitting capabilities. Defaults to False.
input_timing_cache (str, optional): Path to input timing cache file. If None, no input cache used. Defaults to None.
output_timing_cache (str): Path to output timing cache file. Defaults to 'model.cache'.
lora_config (LoraConfig): Configuration for LoRA (Low-Rank Adaptation) fine-tuning. Defaults to default LoraConfig.
weight_sparsity (bool): Whether to enable weight sparsity optimization. Defaults to False.
weight_streaming (bool): Whether to enable weight streaming for large models. Defaults to False.
plugin_config (PluginConfig): Configuration for TensorRT LLM plugins. Defaults to default PluginConfig.
use_strip_plan (bool): Whether to use stripped plan for engine building. Defaults to False.
max_encoder_input_len (int): Maximum encoder input length for encoder-decoder models. Defaults to 1024.
dry_run (bool): Whether to perform a dry run without actually building the engine. Defaults to False.
visualize_network (str, optional): Path to save network visualization. If None, no visualization generated. Defaults to None.
monitor_memory (bool): Whether to monitor memory usage during building. Defaults to False.
use_mrope (bool): Whether to use Multi-RoPE (Rotary Position Embedding) optimization. Defaults to False.
"""
max_input_len: int = 1024
max_seq_len: int = None
opt_batch_size: int = 8
max_batch_size: int = 2048
max_beam_width: int = 1
max_num_tokens: int = 8192
opt_num_tokens: Optional[int] = None
max_prompt_embedding_table_size: int = 0
kv_cache_type: KVCacheType = None
gather_context_logits: int = False
gather_generation_logits: int = False
strongly_typed: bool = True
force_num_profiles: Optional[int] = None
profiling_verbosity: str = 'layer_names_only'
enable_debug_output: bool = False
max_draft_len: int = 0
speculative_decoding_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE
use_refit: bool = False
input_timing_cache: str = None
output_timing_cache: str = 'model.cache'
lora_config: LoraConfig = field(default_factory=LoraConfig)
weight_sparsity: bool = False
weight_streaming: bool = False
plugin_config: PluginConfig = field(default_factory=PluginConfig)
use_strip_plan: bool = False
max_encoder_input_len: int = 1024 # for enc-dec DecoderModel
dry_run: bool = False
visualize_network: str = None
monitor_memory: bool = False
use_mrope: bool = False
max_input_len: int = Field(default=1024,
description="Maximum length of input sequences.")
max_seq_len: Optional[int] = Field(
default=None,
description=
"The maximum possible sequence length for a single request, including both input and generated "
"output tokens.")
opt_batch_size: int = Field(
default=8, description="Optimal batch size for engine optimization.")
max_batch_size: int = Field(
default=2048, description="Maximum batch size the engine can handle.")
max_beam_width: int = Field(
default=1, description="Maximum beam width for beam search decoding.")
max_num_tokens: int = Field(
default=8192,
description="Maximum number of batched input tokens after padding is "
"removed in each batch.")
opt_num_tokens: Optional[int] = Field(
default=None,
description=
"Optimal number of batched input tokens for engine optimization.")
max_prompt_embedding_table_size: int = Field(
default=0,
description="Maximum size of prompt embedding table for prompt tuning.")
kv_cache_type: Optional[KVCacheType] = Field(
default=None,
description=
"Type of KV cache to use (CONTINUOUS or PAGED). If None, defaults to PAGED."
)
gather_context_logits: bool = Field(
default=False,
description="Whether to gather logits during context phase.")
gather_generation_logits: bool = Field(
default=False,
description="Whether to gather logits during generation phase.")
strongly_typed: bool = Field(default=True,
description="Whether to use strongly_typed.")
force_num_profiles: Optional[int] = Field(
default=None,
description=
"Force a specific number of optimization profiles. If None, auto-determined."
)
profiling_verbosity: str = Field(
default='layer_names_only',
description=
"Verbosity level for TensorRT profiling ('layer_names_only', 'detailed', 'none')."
)
enable_debug_output: bool = Field(
default=False,
description="Whether to enable debug output during building.")
max_draft_len: int = Field(
default=0,
description="Maximum length of draft tokens for speculative decoding.")
speculative_decoding_mode: SpeculativeDecodingMode = Field(
default=SpeculativeDecodingMode.NONE,
description="Mode for speculative decoding (NONE, MEDUSA, EAGLE, etc.)."
)
use_refit: bool = Field(
default=False,
description="Whether to enable engine refitting capabilities.")
input_timing_cache: Optional[str] = Field(
default=None,
description=
"Path to input timing cache file. If None, no input cache used.")
output_timing_cache: str = Field(
default='model.cache', description="Path to output timing cache file.")
lora_config: LoraConfig = Field(
default_factory=LoraConfig,
description="Configuration for LoRA (Low-Rank Adaptation) fine-tuning.")
weight_sparsity: bool = Field(
default=False,
description="Whether to enable weight sparsity optimization.")
weight_streaming: bool = Field(
default=False,
description="Whether to enable weight streaming for large models.")
plugin_config: PluginConfig = Field(
default_factory=PluginConfig,
description="Configuration for TensorRT LLM plugins.")
use_strip_plan: bool = Field(
default=False,
description="Whether to use stripped plan for engine building.")
max_encoder_input_len: int = Field(
default=1024,
description="Maximum encoder input length for encoder-decoder models.")
dry_run: bool = Field(
default=False,
description=
"Whether to perform a dry run without actually building the engine.")
visualize_network: Optional[str] = Field(
default=None,
description=
"Path to save network visualization. If None, no visualization generated."
)
monitor_memory: bool = Field(
default=False,
description="Whether to monitor memory usage during building.")
use_mrope: bool = Field(
default=False,
description=
"Whether to use Multi-RoPE (Rotary Position Embedding) optimization.")
# Since we have some overlapping between kv_cache_type, paged_kv_cache, and paged_state (later two will be deprecated in the future),
# we need to handle it given model architecture.
@ -574,144 +605,10 @@ class BuildConfig:
override_attri('paged_state', False)
@classmethod
@cache
def get_build_config_defaults(cls):
return {
field.name: field.default
for field in dataclasses.fields(cls)
if field.default is not dataclasses.MISSING
}
@classmethod
def from_dict(cls, config, plugin_config=None):
config = copy.deepcopy(
config
) # it just does not make sense to change the input arg `config`
defaults = cls.get_build_config_defaults()
max_input_len = config.pop('max_input_len',
defaults.get('max_input_len'))
max_seq_len = config.pop('max_seq_len', defaults.get('max_seq_len'))
max_batch_size = config.pop('max_batch_size',
defaults.get('max_batch_size'))
max_beam_width = config.pop('max_beam_width',
defaults.get('max_beam_width'))
max_num_tokens = config.pop('max_num_tokens',
defaults.get('max_num_tokens'))
opt_num_tokens = config.pop('opt_num_tokens',
defaults.get('opt_num_tokens'))
opt_batch_size = config.pop('opt_batch_size',
defaults.get('opt_batch_size'))
max_prompt_embedding_table_size = config.pop(
'max_prompt_embedding_table_size',
defaults.get('max_prompt_embedding_table_size'))
if "kv_cache_type" in config and config["kv_cache_type"] is not None:
kv_cache_type = KVCacheType.from_string(config.pop('kv_cache_type'))
else:
kv_cache_type = None
gather_context_logits = config.pop(
'gather_context_logits', defaults.get('gather_context_logits'))
gather_generation_logits = config.pop(
'gather_generation_logits',
defaults.get('gather_generation_logits'))
strongly_typed = config.pop('strongly_typed',
defaults.get('strongly_typed'))
force_num_profiles = config.pop('force_num_profiles',
defaults.get('force_num_profiles'))
weight_sparsity = config.pop('weight_sparsity',
defaults.get('weight_sparsity'))
profiling_verbosity = config.pop('profiling_verbosity',
defaults.get('profiling_verbosity'))
enable_debug_output = config.pop('enable_debug_output',
defaults.get('enable_debug_output'))
max_draft_len = config.pop('max_draft_len',
defaults.get('max_draft_len'))
speculative_decoding_mode = config.pop(
'speculative_decoding_mode',
defaults.get('speculative_decoding_mode'))
use_refit = config.pop('use_refit', defaults.get('use_refit'))
input_timing_cache = config.pop('input_timing_cache',
defaults.get('input_timing_cache'))
output_timing_cache = config.pop('output_timing_cache',
defaults.get('output_timing_cache'))
lora_config = LoraConfig(**config.get('lora_config', {}))
max_encoder_input_len = config.pop(
'max_encoder_input_len', defaults.get('max_encoder_input_len'))
weight_streaming = config.pop('weight_streaming',
defaults.get('weight_streaming'))
use_strip_plan = config.pop('use_strip_plan',
defaults.get('use_strip_plan'))
if plugin_config is None:
plugin_config = PluginConfig()
if "plugin_config" in config.keys():
plugin_config = plugin_config.model_copy(
update=config["plugin_config"], deep=True)
dry_run = config.pop('dry_run', defaults.get('dry_run'))
visualize_network = config.pop('visualize_network',
defaults.get('visualize_network'))
monitor_memory = config.pop('monitor_memory',
defaults.get('monitor_memory'))
use_mrope = config.pop('use_mrope', defaults.get('use_mrope'))
return cls(
max_input_len=max_input_len,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_num_tokens=max_num_tokens,
opt_num_tokens=opt_num_tokens,
opt_batch_size=opt_batch_size,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
kv_cache_type=kv_cache_type,
gather_context_logits=gather_context_logits,
gather_generation_logits=gather_generation_logits,
strongly_typed=strongly_typed,
force_num_profiles=force_num_profiles,
profiling_verbosity=profiling_verbosity,
enable_debug_output=enable_debug_output,
max_draft_len=max_draft_len,
speculative_decoding_mode=speculative_decoding_mode,
use_refit=use_refit,
input_timing_cache=input_timing_cache,
output_timing_cache=output_timing_cache,
lora_config=lora_config,
use_strip_plan=use_strip_plan,
max_encoder_input_len=max_encoder_input_len,
weight_sparsity=weight_sparsity,
weight_streaming=weight_streaming,
plugin_config=plugin_config,
dry_run=dry_run,
visualize_network=visualize_network,
monitor_memory=monitor_memory,
use_mrope=use_mrope)
@classmethod
def from_json_file(cls, config_file, plugin_config=None):
def from_json_file(cls, config_file):
with open(config_file) as f:
config = json.load(f)
return BuildConfig.from_dict(config, plugin_config=plugin_config)
def to_dict(self):
output = copy.deepcopy(self.__dict__)
# the enum KVCacheType cannot be converted automatically
if output.get('kv_cache_type', None) is not None:
output['kv_cache_type'] = str(output['kv_cache_type'].name)
output['plugin_config'] = output['plugin_config'].model_dump()
output['lora_config'] = output['lora_config'].model_dump()
return output
def update_from_dict(self, config: dict):
for name, value in config.items():
if not hasattr(self, name):
raise AttributeError(
f"{self.__class__} object has no attribute {name}")
setattr(self, name, value)
def update(self, **kwargs):
self.update_from_dict(kwargs)
return BuildConfig(**config)
class EngineConfig:
@ -731,11 +628,10 @@ class EngineConfig:
def from_json_str(cls, config_str):
config = json.loads(config_str)
return cls(PretrainedConfig.from_dict(config['pretrained_config']),
BuildConfig.from_dict(config['build_config']),
config['version'])
BuildConfig(**config['build_config']), config['version'])
def to_dict(self):
build_config = self.build_config.to_dict()
build_config = self.build_config.model_dump(mode="json")
build_config.pop('dry_run', None) # Not an Engine Characteristic
build_config.pop('visualize_network',
None) # Not an Engine Characteristic
@ -1081,7 +977,7 @@ def build(model: PretrainedModel, build_config: BuildConfig) -> Engine:
'''
tic = time.time()
# avoid changing the input config
build_config = copy.deepcopy(build_config)
build_config = build_config.model_copy(deep=True)
build_config.plugin_config.dtype = model.config.dtype
build_config.update_kv_cache_type(model.config.architecture)
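build(), the refit path, and the model loader previously guarded the caller's config with copy.deepcopy; with a Pydantic model the same guarantee comes from model_copy(deep=True). A minimal sketch of the behavior these call sites rely on:

from tensorrt_llm.builder import BuildConfig

original = BuildConfig(max_batch_size=16)
working = original.model_copy(deep=True)   # replaces copy.deepcopy(original)
working.max_batch_size = 32                # mutate only the copy
working.plugin_config.nccl_plugin = None   # nested models are copied too
assert original.max_batch_size == 16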

View File

@ -26,8 +26,8 @@ import torch
from tensorrt_llm._utils import (local_mpi_rank, local_mpi_size, mpi_barrier,
mpi_comm, mpi_rank, mpi_world_size)
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.builder import BuildConfig, Engine, build
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.logger import logger, severity_map
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.lora_manager import LoraManager
@ -37,23 +37,6 @@ from tensorrt_llm.plugin import PluginConfig, add_plugin_argument
from tensorrt_llm.quantization.mode import QuantAlgo
def enum_type(enum_class):
def parse_enum(value):
if isinstance(value, enum_class):
return value
if isinstance(value, str):
return enum_class.from_string(value)
valid_values = [e.name for e in enum_class]
raise argparse.ArgumentTypeError(
f"Invalid value '{value}' of type {type(value).__name__}. Expected one of {valid_values}"
)
return parse_enum
def parse_arguments():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@ -92,29 +75,30 @@ def parse_arguments():
parser.add_argument(
'--max_batch_size',
type=int,
default=BuildConfig.max_batch_size,
default=BuildConfig.model_fields["max_batch_size"].default,
help="Maximum number of requests that the engine can schedule.")
parser.add_argument('--max_input_len',
type=int,
default=BuildConfig.max_input_len,
help="Maximum input length of one request.")
parser.add_argument(
'--max_input_len',
type=int,
default=BuildConfig.model_fields["max_input_len"].default,
help="Maximum input length of one request.")
parser.add_argument(
'--max_seq_len',
'--max_decoder_seq_len',
dest='max_seq_len',
type=int,
default=BuildConfig.max_seq_len,
default=BuildConfig.model_fields["max_seq_len"].default,
help="Maximum total length of one request, including prompt and outputs. "
"If unspecified, the value is deduced from the model config.")
parser.add_argument(
'--max_beam_width',
type=int,
default=BuildConfig.max_beam_width,
default=BuildConfig.model_fields["max_beam_width"].default,
help="Maximum number of beams for beam search decoding.")
parser.add_argument(
'--max_num_tokens',
type=int,
default=BuildConfig.max_num_tokens,
default=BuildConfig.model_fields["max_num_tokens"].default,
help=
"Maximum number of batched input tokens after padding is removed in each batch. "
"Currently, the input padding is removed by default; "
@ -123,7 +107,7 @@ def parse_arguments():
parser.add_argument(
'--opt_num_tokens',
type=int,
default=BuildConfig.opt_num_tokens,
default=BuildConfig.model_fields["opt_num_tokens"].default,
help=
"Optimal number of batched input tokens after padding is removed in each batch "
"It equals to ``max_batch_size * max_beam_width`` by default, set this "
@ -132,7 +116,7 @@ def parse_arguments():
parser.add_argument(
'--max_encoder_input_len',
type=int,
default=BuildConfig.max_encoder_input_len,
default=BuildConfig.model_fields["max_encoder_input_len"].default,
help="Maximum encoder input length for enc-dec models. "
"Set ``max_input_len`` to 1 to start generation from decoder_start_token_id of length 1."
)
@ -140,14 +124,15 @@ def parse_arguments():
'--max_prompt_embedding_table_size',
'--max_multimodal_len',
type=int,
default=BuildConfig.max_prompt_embedding_table_size,
default=BuildConfig.model_fields["max_prompt_embedding_table_size"].
default,
help=
"Maximum prompt embedding table size for prompt tuning, or maximum multimodal input size for multimodal models. "
"Setting a value > 0 enables prompt tuning or multimodal input.")
parser.add_argument(
'--kv_cache_type',
default=argparse.SUPPRESS,
type=enum_type(KVCacheType),
type=KVCacheType,
help=
"Set KV cache type (continuous, paged, or disabled). For disabled case, KV cache is disabled and only context phase is allowed."
)
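Because the new KVCacheType is a plain str-valued Enum with a case-insensitive _missing_ hook, argparse can use the enum class itself as the type callable; the local enum_type() helper and KVCacheType.from_string() removed above are no longer needed. A sketch, assuming the module path added in this commit:

import argparse
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType

parser = argparse.ArgumentParser()
parser.add_argument("--kv_cache_type", type=KVCacheType, default=argparse.SUPPRESS)

args = parser.parse_args(["--kv_cache_type", "PAGED"])   # case-insensitive
assert args.kv_cache_type is KVCacheType.PAGED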
@ -156,42 +141,44 @@ def parse_arguments():
type=str,
default=argparse.SUPPRESS,
help=
"Deprecated. Enabling this option is equvilient to ``--kv_cache_type paged`` for transformer based models."
"Deprecated. Enabling this option is equivalent to ``--kv_cache_type paged`` for transformer based models."
)
parser.add_argument(
'--input_timing_cache',
type=str,
default=BuildConfig.input_timing_cache,
default=BuildConfig.model_fields["input_timing_cache"].default,
help=
"The file path to read the timing cache. This option is ignored if the file does not exist."
)
parser.add_argument('--output_timing_cache',
type=str,
default=BuildConfig.output_timing_cache,
help="The file path to write the timing cache.")
parser.add_argument(
'--output_timing_cache',
type=str,
default=BuildConfig.model_fields["output_timing_cache"].default,
help="The file path to write the timing cache.")
parser.add_argument(
'--profiling_verbosity',
type=str,
default=BuildConfig.profiling_verbosity,
default=BuildConfig.model_fields["profiling_verbosity"].default,
choices=['layer_names_only', 'detailed', 'none'],
help=
"The profiling verbosity for the generated TensorRT engine. Setting to detailed allows inspecting tactic choices and kernel parameters."
)
parser.add_argument(
'--strip_plan',
default=BuildConfig.use_strip_plan,
default=BuildConfig.model_fields["use_strip_plan"].default,
action='store_true',
help=
"Enable stripping weights from the final TensorRT engine under the assumption that the refit weights are identical to those provided at build time."
)
parser.add_argument('--weight_sparsity',
default=BuildConfig.weight_sparsity,
action='store_true',
help="Enable weight sparsity.")
parser.add_argument(
'--weight_sparsity',
default=BuildConfig.model_fields["weight_sparsity"].default,
action='store_true',
help="Enable weight sparsity.")
parser.add_argument(
'--weight_streaming',
default=BuildConfig.weight_streaming,
default=BuildConfig.model_fields["weight_streaming"].default,
action='store_true',
help=
"Enable offloading weights to CPU and streaming loading at runtime.",
@ -213,10 +200,11 @@ def parse_arguments():
default='info',
choices=severity_map.keys(),
help="The logging level.")
parser.add_argument('--enable_debug_output',
default=BuildConfig.enable_debug_output,
action='store_true',
help="Enable debug output.")
parser.add_argument(
'--enable_debug_output',
default=BuildConfig.model_fields["enable_debug_output"].default,
action='store_true',
help="Enable debug output.")
parser.add_argument(
'--visualize_network',
type=str,
@ -226,7 +214,7 @@ def parse_arguments():
)
parser.add_argument(
'--dry_run',
default=BuildConfig.dry_run,
default=BuildConfig.model_fields["dry_run"].default,
action='store_true',
help=
"Run through the build process except the actual Engine build for debugging."
@ -519,65 +507,37 @@ def main():
f"Overriding # of builder profiles <= {force_num_profiles_from_env}."
)
build_config = BuildConfig.from_dict(
{
'max_input_len':
args.max_input_len,
'max_seq_len':
args.max_seq_len,
'max_batch_size':
args.max_batch_size,
'max_beam_width':
args.max_beam_width,
'max_num_tokens':
args.max_num_tokens,
'opt_num_tokens':
args.opt_num_tokens,
'max_prompt_embedding_table_size':
args.max_prompt_embedding_table_size,
'gather_context_logits':
args.gather_context_logits,
'gather_generation_logits':
args.gather_generation_logits,
'strongly_typed':
True,
'force_num_profiles':
force_num_profiles_from_env,
'weight_sparsity':
args.weight_sparsity,
'profiling_verbosity':
args.profiling_verbosity,
'enable_debug_output':
args.enable_debug_output,
'max_draft_len':
args.max_draft_len,
'speculative_decoding_mode':
speculative_decoding_mode,
'input_timing_cache':
args.input_timing_cache,
'output_timing_cache':
args.output_timing_cache,
'dry_run':
args.dry_run,
'visualize_network':
args.visualize_network,
'max_encoder_input_len':
args.max_encoder_input_len,
'weight_streaming':
args.weight_streaming,
'monitor_memory':
args.monitor_memory,
'use_mrope':
(True if model_config.qwen_type == "qwen2_vl" else False)
if hasattr(model_config, "qwen_type") else False
},
build_config = BuildConfig(
max_input_len=args.max_input_len,
max_seq_len=args.max_seq_len,
max_batch_size=args.max_batch_size,
max_beam_width=args.max_beam_width,
max_num_tokens=args.max_num_tokens,
opt_num_tokens=args.opt_num_tokens,
max_prompt_embedding_table_size=args.
max_prompt_embedding_table_size,
kv_cache_type=getattr(args, "kv_cache_type", None),
gather_context_logits=args.gather_context_logits,
gather_generation_logits=args.gather_generation_logits,
strongly_typed=True,
force_num_profiles=force_num_profiles_from_env,
weight_sparsity=args.weight_sparsity,
profiling_verbosity=args.profiling_verbosity,
enable_debug_output=args.enable_debug_output,
max_draft_len=args.max_draft_len,
speculative_decoding_mode=speculative_decoding_mode,
input_timing_cache=args.input_timing_cache,
output_timing_cache=args.output_timing_cache,
dry_run=args.dry_run,
visualize_network=args.visualize_network,
max_encoder_input_len=args.max_encoder_input_len,
weight_streaming=args.weight_streaming,
monitor_memory=args.monitor_memory,
use_mrope=getattr(model_config, "qwen_type", None) == "qwen2_vl",
plugin_config=plugin_config)
if hasattr(args, 'kv_cache_type'):
build_config.update_from_dict({'kv_cache_type': args.kv_cache_type})
else:
build_config = BuildConfig.from_json_file(args.build_config,
plugin_config=plugin_config)
build_config = BuildConfig.from_json_file(args.build_config)
build_config.plugin_config = plugin_config
parallel_build(model_config, ckpt_dir, build_config, args.output_dir,
workers, args.log_level, model_cls, **kwargs)

View File

@ -50,23 +50,23 @@ from ..logger import logger, severity_map
help="The logging level.")
@click.option("--max_beam_width",
type=int,
default=BuildConfig.max_beam_width,
default=BuildConfig.model_fields["max_beam_width"].default,
help="Maximum number of beams for beam search decoding.")
@click.option("--max_batch_size",
type=int,
default=BuildConfig.max_batch_size,
default=BuildConfig.model_fields["max_batch_size"].default,
help="Maximum number of requests that the engine can schedule.")
@click.option(
"--max_num_tokens",
type=int,
default=BuildConfig.max_num_tokens,
default=BuildConfig.model_fields["max_num_tokens"].default,
help=
"Maximum number of batched input tokens after padding is removed in each batch."
)
@click.option(
"--max_seq_len",
type=int,
default=BuildConfig.max_seq_len,
default=BuildConfig.model_fields["max_seq_len"].default,
help="Maximum total length of one request, including prompt and outputs. "
"If unspecified, the value is deduced from the model config.")
@click.option("--tp_size", type=int, default=1, help='Tensor parallelism size.')

View File

@ -2,7 +2,6 @@
Script that refits TRT-LLM engine(s) with weights in a TRT-LLM checkpoint.
'''
import argparse
import copy
import json
import os
import re
@ -57,7 +56,7 @@ def refit_engine(engine_path: str, refit_engine_dir: str, checkpoint_dir: str,
# There are weights preprocess during optimize model.
tik = time.time()
build_config = copy.deepcopy(engine_config.build_config)
build_config = engine_config.build_config.model_copy(deep=True)
optimize_model_with_config(model, build_config)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))

View File

@ -75,25 +75,29 @@ def _signal_handler_cleanup_child(signum, frame):
sys.exit(128 + signum)
def get_llm_args(model: str,
tokenizer: Optional[str] = None,
backend: str = "pytorch",
max_beam_width: int = BuildConfig.max_beam_width,
max_batch_size: int = BuildConfig.max_batch_size,
max_num_tokens: int = BuildConfig.max_num_tokens,
max_seq_len: int = BuildConfig.max_seq_len,
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
moe_expert_parallel_size: Optional[int] = None,
gpus_per_node: Optional[int] = None,
free_gpu_memory_fraction: float = 0.9,
num_postprocess_workers: int = 0,
trust_remote_code: bool = False,
reasoning_parser: Optional[str] = None,
fail_fast_on_attention_window_too_large: bool = False,
otlp_traces_endpoint: Optional[str] = None,
enable_chunked_prefill: bool = False,
**llm_args_extra_dict: Any):
def get_llm_args(
model: str,
tokenizer: Optional[str] = None,
backend: str = "pytorch",
max_beam_width: int = BuildConfig.model_fields["max_beam_width"].
default,
max_batch_size: int = BuildConfig.model_fields["max_batch_size"].
default,
max_num_tokens: int = BuildConfig.model_fields["max_num_tokens"].
default,
max_seq_len: int = BuildConfig.model_fields["max_seq_len"].default,
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
moe_expert_parallel_size: Optional[int] = None,
gpus_per_node: Optional[int] = None,
free_gpu_memory_fraction: float = 0.9,
num_postprocess_workers: int = 0,
trust_remote_code: bool = False,
reasoning_parser: Optional[str] = None,
fail_fast_on_attention_window_too_large: bool = False,
otlp_traces_endpoint: Optional[str] = None,
enable_chunked_prefill: bool = False,
**llm_args_extra_dict: Any):
if gpus_per_node is None:
gpus_per_node = device_count()
@ -242,23 +246,23 @@ class ChoiceWithAlias(click.Choice):
help="The logging level.")
@click.option("--max_beam_width",
type=int,
default=BuildConfig.max_beam_width,
default=BuildConfig.model_fields["max_beam_width"].default,
help="Maximum number of beams for beam search decoding.")
@click.option("--max_batch_size",
type=int,
default=BuildConfig.max_batch_size,
default=BuildConfig.model_fields["max_batch_size"].default,
help="Maximum number of requests that the engine can schedule.")
@click.option(
"--max_num_tokens",
type=int,
default=BuildConfig.max_num_tokens,
default=BuildConfig.model_fields["max_num_tokens"].default,
help=
"Maximum number of batched input tokens after padding is removed in each batch."
)
@click.option(
"--max_seq_len",
type=int,
default=BuildConfig.max_seq_len,
default=BuildConfig.model_fields["max_seq_len"].default,
help="Maximum total length of one request, including prompt and outputs. "
"If unspecified, the value is deduced from the model config.")
@click.option("--tp_size", type=int, default=1, help='Tensor parallelism size.')
@ -436,7 +440,7 @@ def serve(
help="The logging level.")
@click.option("--max_batch_size",
type=int,
default=BuildConfig.max_batch_size,
default=BuildConfig.model_fields["max_batch_size"].default,
help="Maximum number of requests that the engine can schedule.")
@click.option(
"--max_num_tokens",

View File

@ -104,7 +104,7 @@ class BuildCache:
Get the build step for engine building.
'''
build_config_str = json.dumps(self.prune_build_config_for_cache_key(
build_config.to_dict()),
build_config.model_dump(mode="json")),
sort_keys=True)
kwargs_str = json.dumps(kwargs, sort_keys=True)

View File

@ -0,0 +1,50 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import tensorrt_llm.bindings as _bindings
class KVCacheType(str, Enum):
"""Python enum wrapper for KVCacheType.
This is a pure Python enum that mirrors the C++ KVCacheType enum exposed
through pybind11.
"""
CONTINUOUS = "continuous"
PAGED = "paged"
DISABLED = "disabled"
@classmethod
def _missing_(cls, value):
"""Allow case-insensitive string values to be converted to enum members."""
if isinstance(value, str):
for member in cls:
if member.value.lower() == value.lower():
return member
return None
def to_cpp(self) -> '_bindings.KVCacheType':
import tensorrt_llm.bindings as _bindings
return getattr(_bindings.KVCacheType, self.name)
@classmethod
def from_cpp(cls, cpp_enum) -> 'KVCacheType':
# C++ enum's __str__ returns "KVCacheType.PAGED", extract the name
name = str(cpp_enum).split('.')[-1]
return cls(name)
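The wrapper is intentionally small: values are lowercase strings so they serialize naturally through Pydantic and JSON, _missing_ keeps the old case-insensitive from_string() behavior, and to_cpp()/from_cpp() bridge to the pybind11 enum where the runtime still needs it. A usage sketch:

from tensorrt_llm.llmapi.kv_cache_type import KVCacheType

assert KVCacheType("paged") is KVCacheType.PAGED
assert KVCacheType("PAGED") is KVCacheType.PAGED      # replaces KVCacheType.from_string("PAGED")
assert KVCacheType.PAGED.value == "paged"             # what model_dump(mode="json") emits

# Bridging to and from the C++ binding enum (requires the compiled bindings).
cpp_value = KVCacheType.PAGED.to_cpp()
assert KVCacheType.from_cpp(cpp_value) is KVCacheType.PAGED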

View File

@ -1,5 +1,4 @@
import ast
import copy
import functools
import json
import math
@ -1764,17 +1763,6 @@ class BaseLlmArgs(StrictBaseModel):
ret = cls(**kwargs)
return ret
def to_dict(self) -> dict:
"""Dump `LlmArgs` instance to a dict.
Returns:
dict: The dict that contains all fields of the `LlmArgs` instance.
"""
model_dict = self.model_dump(mode='json')
# TODO: the BuildConfig.to_dict and from_dict don't work well with pydantic
model_dict['build_config'] = copy.deepcopy(self.build_config)
return model_dict
@staticmethod
def _check_consistency(kwargs_dict: Dict[str, Any]) -> Dict[str, Any]:
# max_beam_width is not included since vague behavior due to lacking the support for dynamic beam width during
@ -1919,10 +1907,6 @@ class BaseLlmArgs(StrictBaseModel):
if self.max_input_len:
kwargs["max_input_len"] = self.max_input_len
self.build_config = BuildConfig(**kwargs)
else:
assert isinstance(
build_config,
BuildConfig), f"build_config is not initialized: {build_config}"
return self
@model_validator(mode="after")
@ -2001,7 +1985,7 @@ class BaseLlmArgs(StrictBaseModel):
# TODO: remove the checker when manage weights support all data types
if is_trt_llm_args and self.fast_build and (self.quant_config.quant_algo
is QuantAlgo.FP8):
self._update_plugin_config("manage_weights", True)
self.build_config.plugin_config.manage_weights = True
if self.parallel_config.world_size == 1 and self.build_config:
self.build_config.plugin_config.nccl_plugin = None
@ -2166,9 +2150,6 @@ class BaseLlmArgs(StrictBaseModel):
"while LoRA prefetch is not supported")
return self
def _update_plugin_config(self, key: str, value: Any):
setattr(self.build_config.plugin_config, key, value)
def _load_config_from_engine(self, engine_dir: Path):
engine_config = EngineConfig.from_json_file(engine_dir / "config.json")
self._pretrained_config = engine_config.pretrained_config
@ -2271,10 +2252,8 @@ class TrtLlmArgs(BaseLlmArgs):
fast_build: bool = Field(default=False, description="Enable fast build.")
# BuildConfig is introduced to give users a familiar interface to configure the model building.
build_config: Optional[object] = Field(
default=None,
description="Build config.",
json_schema_extra={"type": f"Optional[{get_type_repr(BuildConfig)}]"})
build_config: Optional[BuildConfig] = Field(default=None,
description="Build config.")
# Prompt adapter arguments
enable_prompt_adapter: bool = Field(default=False,
@ -2405,11 +2384,10 @@ class TorchCompileConfig(StrictBaseModel):
class TorchLlmArgs(BaseLlmArgs):
# Just a dummy BuildConfig to allow code reuse with the TrtLlmArgs
build_config: Optional[object] = Field(
build_config: Optional[BuildConfig] = Field(
default=None,
description="Build config.",
exclude_from_json=True,
json_schema_extra={"type": f"Optional[{get_type_repr(BuildConfig)}]"},
status="deprecated",
)
@ -2911,10 +2889,7 @@ def update_llm_args_with_extra_dict(
for field_name, field_type in field_mapping.items():
if field_name in llm_args_dict:
# Some fields need to be converted manually.
if field_name in [
"speculative_config", "build_config",
"sparse_attention_config"
]:
if field_name in ["speculative_config", "sparse_attention_config"]:
llm_args_dict[field_name] = field_type.from_dict(
llm_args_dict[field_name])
else:
@ -2928,6 +2903,10 @@ def update_llm_args_with_extra_dict(
# For trtllm-bench or trtllm-serve, build_config may be passed for the PyTorch
# backend, overwriting the knobs there since build_config always has the highest priority
if "build_config" in llm_args:
# Ensure build_config is a BuildConfig object, not a dict
if isinstance(llm_args["build_config"], dict):
llm_args["build_config"] = BuildConfig(**llm_args["build_config"])
for key in [
"max_batch_size",
"max_num_tokens",

View File

@ -1,4 +1,3 @@
import copy
import json
import os
import shutil
@ -530,8 +529,8 @@ class ModelLoader:
logger_debug(f"rank{mpi_rank()} begin to build engine...\n", "green")
# avoid the original build_config is modified, avoid the side effect
copied_build_config = copy.deepcopy(self.build_config)
# avoid side effects by copying the original build_config
copied_build_config = self.build_config.model_copy(deep=True)
copied_build_config.update_kv_cache_type(self._model_info.architecture)
assert self.model is not None, "model is loaded yet."

View File

@ -27,10 +27,10 @@ from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader
from ..._common import default_net, default_trtnet
from ..._utils import pad_vocab_size
from ...bindings import KVCacheType
from ...functional import (Tensor, _create_tensor, cast, concat,
gather_last_token_logits, index_select, shape)
from ...layers import AttentionParams, ColumnLinear, SpecDecodingParams
from ...llmapi.kv_cache_type import KVCacheType
from ...module import Module, ModuleList
from ...plugin import TRT_LLM_PLUGIN_NAMESPACE
from ..modeling_utils import QuantConfig

View File

@ -18,9 +18,9 @@ from typing import List, Optional
import tensorrt as trt
from ..bindings import KVCacheType
from ..functional import Tensor
from ..layers import MropeParams, SpecDecodingParams
from ..llmapi.kv_cache_type import KVCacheType
from ..mapping import Mapping
from ..plugin import current_all_reduce_helper

View File

@ -21,7 +21,6 @@ import torch
from tensorrt_llm._common import default_net
from tensorrt_llm._utils import numpy_to_torch, str_dtype_to_torch
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.functional import (Conditional, LayerNormPositionType,
LayerNormType, MLPType,
PositionEmbeddingType, Tensor, assertion,
@ -32,6 +31,7 @@ from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams,
ColumnLinear, Embedding, FusedGatedMLP,
GatedMLP, GroupNorm, KeyValueCacheParams,
LayerNorm, LoraParams, RmsNorm)
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.lora_helper import (LoraConfig,
get_default_trtllm_modules_to_hf_modules,
use_lora)

View File

@ -19,7 +19,6 @@ from .._common import default_net
from .._utils import (QuantModeWrapper, get_init_params, numpy_to_torch,
release_gc, str_dtype_to_torch, str_dtype_to_trt,
trt_dtype_to_torch)
from ..bindings import KVCacheType
from ..bindings.executor import RuntimeDefaults
from ..functional import (PositionEmbeddingType, Tensor, allgather, constant,
cp_split_plugin, gather_last_token_logits,
@ -31,6 +30,7 @@ from ..layers.attention import Attention, BertAttention
from ..layers.linear import ColumnLinear, Linear, RowLinear
from ..layers.lora import Dora, Lora
from ..layers.moe import MOE, MoeOOTB
from ..llmapi.kv_cache_type import KVCacheType
from ..logger import logger
from ..mapping import Mapping
from ..module import Module, ModuleList

View File

@ -15,7 +15,6 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Type, Union
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
AttentionMaskType, PositionEmbeddingType,
Tensor, gather_last_token_logits, recv,
@ -28,6 +27,7 @@ from tensorrt_llm.layers.linear import ColumnLinear
from tensorrt_llm.layers.lora import LoraParams
from tensorrt_llm.layers.mlp import GatedMLP
from tensorrt_llm.layers.normalization import RmsNorm
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import has_safetensors
from tensorrt_llm.models.modeling_utils import DecoderModelForCausalLM

View File

@ -18,8 +18,8 @@ from collections import OrderedDict
import tensorrt as trt
from tensorrt_llm._common import default_net
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.functional import Tensor, cast, categorical_sample
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.models import LLaMAForCausalLM, QWenForCausalLM
from tensorrt_llm.models.generation_mixin import GenerationMixin

View File

@ -31,7 +31,6 @@ from tqdm import tqdm
import tensorrt_llm
from tensorrt_llm._common import default_net
from tensorrt_llm._utils import str_dtype_to_trt, trt_dtype_to_str
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.functional import (ACT2FN, AttentionMaskType, LayerNormType,
PositionEmbeddingType, Tensor,
constant_to_tensor_)
@ -41,6 +40,7 @@ from tensorrt_llm.layers.attention import (Attention, AttentionParams,
BertAttention, KeyValueCacheParams,
bert_attention, layernorm_map)
from tensorrt_llm.layers.normalization import RmsNorm
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.generation_mixin import GenerationMixin
from tensorrt_llm.models.model_weights_loader import (ModelWeightsFormat,

View File

@ -43,8 +43,9 @@ from tensorrt_llm.runtime.redrafter_utils import *
from .._utils import (binding_layer_type_to_str, binding_to_str_dtype,
pad_vocab_size, str_dtype_to_torch, torch_to_numpy,
trt_dtype_to_torch)
from ..bindings import KVCacheType, ipc_nvls_allocate, ipc_nvls_free
from ..bindings import ipc_nvls_allocate, ipc_nvls_free
from ..layers import LanguageAdapterConfig
from ..llmapi.kv_cache_type import KVCacheType
from ..logger import logger
from ..lora_manager import LoraManager
from ..mapping import Mapping

View File

@ -25,9 +25,10 @@ import torch
from .. import profiler
from .._utils import mpi_comm, mpi_world_size, numpy_to_torch
from ..bindings import KVCacheType, MpiComm
from ..bindings import MpiComm
from ..bindings.executor import Executor
from ..builder import Engine, EngineConfig, get_engine_version
from ..llmapi.kv_cache_type import KVCacheType
from ..logger import logger
from ..mapping import Mapping
from ..quantization import QuantMode
@ -86,7 +87,9 @@ def _builder_to_model_config(config: dict) -> Tuple[ModelConfig, dict]:
dtype = builder_config['precision']
tp_size = builder_config['tensor_parallel']
pp_size = builder_config.get('pipeline_parallel', 1)
kv_cache_type = KVCacheType.from_string(builder_config.get('kv_cache_type'))
kv_cache_type = builder_config.get('kv_cache_type')
if kv_cache_type is not None:
kv_cache_type = KVCacheType(kv_cache_type)
world_size = tp_size * pp_size
assert world_size == mpi_world_size(), \
f'Engine world size ({tp_size} * {pp_size}) != Runtime world size ({mpi_world_size()})'

View File

@ -23,13 +23,13 @@ import torch
from .. import profiler
from .._utils import mpi_broadcast
from ..bindings import (DataType, GptJsonConfig, KVCacheType, ModelConfig,
WorldConfig)
from ..bindings import DataType, GptJsonConfig, ModelConfig, WorldConfig
from ..bindings import executor as trtllm
from ..bindings.executor import (DecodingMode, ExternalDraftTokensConfig,
OrchestratorConfig, ParallelConfig)
from ..builder import EngineConfig
from ..layers import MropeParams
from ..llmapi.kv_cache_type import KVCacheType
from ..logger import logger
from ..mapping import Mapping
from .generation import LogitsProcessor, LoraManager
@ -248,7 +248,8 @@ class ModelRunnerCpp(ModelRunnerMixin):
json_config = GptJsonConfig.parse_file(config_path)
model_config = json_config.model_config
use_kv_cache = model_config.kv_cache_type != KVCacheType.DISABLED
use_kv_cache = KVCacheType.from_cpp(
model_config.kv_cache_type) != KVCacheType.DISABLED
if not model_config.use_cross_attention:
assert cross_kv_cache_fraction is None, "cross_kv_cache_fraction should only be used with enc-dec models."
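The check above operates on the pybind11 enum returned by model_config.kv_cache_type, so it is converted to the Python wrapper before comparison. A hedged sketch of that check as a standalone helper (uses_kv_cache is a hypothetical name, not part of this change):

from tensorrt_llm.llmapi.kv_cache_type import KVCacheType

def uses_kv_cache(model_config) -> bool:
    # model_config.kv_cache_type is a tensorrt_llm.bindings.KVCacheType value;
    # convert it before comparing against the Python enum.
    return KVCacheType.from_cpp(model_config.kv_cache_type) != KVCacheType.DISABLED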

View File

@ -671,7 +671,8 @@ class CliFlowAccuracyTestHarness:
f"--max_tokens_in_paged_kv_cache={max_tokens_in_paged_kv_cache}"
])
if task.MAX_INPUT_LEN + task.MAX_OUTPUT_LEN > BuildConfig.max_num_tokens:
if task.MAX_INPUT_LEN + task.MAX_OUTPUT_LEN > BuildConfig.model_fields[
"max_num_tokens"].default:
summarize_cmd.append("--enable_chunked_context")
if self.extra_summarize_args:

View File

@ -142,14 +142,14 @@ def test_llmapi_build_command_parameters_align(llm_root, llm_venv, engine_dir,
with open(os.path.join(engine_dir, "config.json"), "r") as f:
engine_config = json.load(f)
build_cmd_cfg = BuildConfig.from_dict(
engine_config["build_config"]).to_dict()
build_cmd_cfg = BuildConfig(
**engine_config["build_config"]).model_dump()
with open(os.path.join(tmpdir.name, "config.json"), "r") as f:
llm_api_engine_cfg = json.load(f)
build_llmapi_cfg = BuildConfig.from_dict(
llm_api_engine_cfg["build_config"]).to_dict()
build_llmapi_cfg = BuildConfig(
**llm_api_engine_cfg["build_config"]).model_dump()
assert build_cmd_cfg == build_llmapi_cfg

View File

@ -1636,7 +1636,7 @@ def get_allowed_models(benchmark_type=None):
if i.benchmark_type == benchmark_type)
def get_build_config(model_name, return_dict=True) -> Union[BuildConfig]:
def get_build_config(model_name, return_dict=True) -> Union[Dict, BuildConfig]:
if model_name in _allowed_configs:
cfg = _allowed_configs[model_name].build_config
return asdict(cfg) if return_dict else cfg

View File

@ -255,7 +255,7 @@ def get_quant_config(quantization: str):
def build_gpt(args):
build_config = get_build_config(args.model)
build_config = BuildConfig.from_dict(build_config)
build_config = BuildConfig(**build_config)
model_config = get_model_config(args.model)
if args.force_num_layer_1:
model_config['num_layers'] = 1
@ -1448,7 +1448,7 @@ def enc_dec_build_helper(component, build_config, model_config, args):
def build_enc_dec(args):
build_config = get_build_config(args.model)
build_config = BuildConfig.from_dict(build_config)
build_config = BuildConfig(**build_config)
model_config = get_model_config(args.model)
if args.force_num_layer_1:
model_config['num_layers'] = 1

View File

@ -9,6 +9,7 @@ import torch
from utils.runtime_defaults import assert_runtime_defaults_are_parsed_correctly
import tensorrt_llm.bindings as _tb
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.mapping import Mapping
@ -85,6 +86,7 @@ def test_model_config():
assert model_config.use_packed_input
assert model_config.kv_cache_type is not None
# Test with C++ enums directly
for enum_val in [
_tb.KVCacheType.CONTINUOUS, _tb.KVCacheType.PAGED,
_tb.KVCacheType.DISABLED
@ -92,6 +94,17 @@ def test_model_config():
model_config.kv_cache_type = enum_val
assert model_config.kv_cache_type == enum_val
# Test with Python enums converted to C++
for py_enum in [
KVCacheType.CONTINUOUS, KVCacheType.PAGED, KVCacheType.DISABLED
]:
model_config.kv_cache_type = py_enum.to_cpp()
# Verify it was set correctly by comparing with C++ enum
assert model_config.kv_cache_type == getattr(_tb.KVCacheType,
py_enum.name)
# Also verify round-trip conversion works
assert KVCacheType.from_cpp(model_config.kv_cache_type) == py_enum
assert model_config.tokens_per_block == 64
tokens_per_block = 1024
model_config.tokens_per_block = tokens_per_block

View File

@ -202,7 +202,7 @@ def test_llm_build_config():
# read the build_config and check if the parameters are correctly saved
engine_config = json.load(f)
build_config1 = BuildConfig.from_dict(engine_config["build_config"])
build_config1 = BuildConfig(**engine_config["build_config"])
# Know issue: this will be converted to None after save engine for single-gpu
build_config1.plugin_config.nccl_plugin = 'float16'

View File

@ -64,7 +64,7 @@ speculative_config:
dict_content = self._yaml_to_dict(yaml_content)
llm_args = TrtLlmArgs(model=llama_model_path)
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
dict_content)
llm_args = TrtLlmArgs(**llm_args_dict)
assert llm_args.speculative_config.max_window_size == 4
@ -80,7 +80,7 @@ pytorch_backend_config: # this is deprecated
dict_content = self._yaml_to_dict(yaml_content)
llm_args = TrtLlmArgs(model=llama_model_path)
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
dict_content)
with pytest.raises(ValueError):
llm_args = TrtLlmArgs(**llm_args_dict)
@ -96,7 +96,7 @@ build_config:
dict_content = self._yaml_to_dict(yaml_content)
llm_args = TrtLlmArgs(model=llama_model_path)
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
dict_content)
llm_args = TrtLlmArgs(**llm_args_dict)
assert llm_args.build_config.max_beam_width == 4
@ -113,7 +113,7 @@ kv_cache_config:
dict_content = self._yaml_to_dict(yaml_content)
llm_args = TrtLlmArgs(model=llama_model_path)
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
dict_content)
llm_args = TrtLlmArgs(**llm_args_dict)
assert llm_args.kv_cache_config.enable_block_reuse == True
@ -131,7 +131,7 @@ max_seq_len: 128
dict_content = self._yaml_to_dict(yaml_content)
llm_args = TrtLlmArgs(model=llama_model_path)
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
dict_content)
llm_args = TrtLlmArgs(**llm_args_dict)
assert llm_args.max_batch_size == 16
@ -331,7 +331,7 @@ def test_update_llm_args_with_extra_dict_with_nested_dict():
lora_config=LoraConfig(lora_ckpt_source='hf'),
plugin_config=plugin_config)
extra_llm_args_dict = {
"build_config": build_config.to_dict(),
"build_config": build_config.model_dump(mode="json"),
}
llm_api_args_dict = update_llm_args_with_extra_dict(llm_api_args_dict,
@ -352,8 +352,9 @@ def test_update_llm_args_with_extra_dict_with_nested_dict():
raise ValueError(f"Mismatch at {path}: {dict1} != {dict2}")
return True
build_config_dict1 = build_config.to_dict()
build_config_dict2 = initialized_llm_args.build_config.to_dict()
build_config_dict1 = build_config.model_dump(mode="json")
build_config_dict2 = initialized_llm_args.build_config.model_dump(
mode="json")
check_nested_dict_equality(build_config_dict1, build_config_dict2)
@ -498,11 +499,11 @@ class TestTrtLlmArgs:
max_num_tokens=256,
)
args = TrtLlmArgs(model=llama_model_path, build_config=build_config)
args_dict = args.to_dict()
args_dict = args.model_dump()
new_args = TrtLlmArgs.from_kwargs(**args_dict)
assert new_args.to_dict() == args_dict
assert new_args.model_dump() == args_dict
def test_build_config_from_engine(self):
build_config = BuildConfig(max_batch_size=8, max_num_tokens=256)
@ -522,6 +523,36 @@ class TestTrtLlmArgs:
assert args.max_num_tokens == 16
assert args.max_batch_size == 4
def test_model_dump_does_not_mutate_original(self):
"""Test that model_dump() and update_llm_args_with_extra_dict don't mutate the original."""
# Create args with specific build_config values
build_config = BuildConfig(
max_batch_size=8,
max_num_tokens=256,
)
args = TrtLlmArgs(model=llama_model_path, build_config=build_config)
# Store original values
original_max_batch_size = args.build_config.max_batch_size
original_max_num_tokens = args.build_config.max_num_tokens
# Convert to dict and pass through update_llm_args_with_extra_dict with overrides
args_dict = args.model_dump()
extra_dict = {
"max_batch_size": 128,
"max_num_tokens": 1024,
}
updated_dict = update_llm_args_with_extra_dict(args_dict, extra_dict)
# Verify original args was NOT mutated
assert args.build_config.max_batch_size == original_max_batch_size
assert args.build_config.max_num_tokens == original_max_num_tokens
# Verify updated dict has new values
new_args = TrtLlmArgs(**updated_dict)
assert new_args.build_config.max_batch_size == 128
assert new_args.build_config.max_num_tokens == 1024
class TestStrictBaseModelArbitraryArgs:
"""Test that StrictBaseModel prevents arbitrary arguments from being accepted."""