Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

[TRTLLM-8684][chore] Migrate BuildConfig to Pydantic, add a Python wrapper for KVCacheType enum (#8330)
Signed-off-by: Anish Shanbhag <ashanbhag@nvidia.com>

This commit is contained in:
parent cdc9e5e645
commit a09b38a862
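At a glance, the commit swaps the pybind11-backed enum for a pure-Python one and drops the `from_string` helper at call sites. A minimal before/after sketch, with names and values taken from the hunks below (the `build_config` dict stands in for an engine's `config['build_config']`):

# Illustrative sketch only; names and values taken from the hunks below.
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType   # new import path
# from tensorrt_llm.bindings import KVCacheType             # old import path

build_config = {'kv_cache_type': 'paged'}

# Old call site: kv_cache_type = KVCacheType.from_string(build_config['kv_cache_type'])
# New call site: the str enum is constructed directly from the string value.
kv_cache_type = KVCacheType(build_config['kv_cache_type'])
assert kv_cache_type is KVCacheType.PAGED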
@ -23,7 +23,7 @@ from transformers import AutoModelForCausalLM, LlamaTokenizer
import tensorrt_llm
import tensorrt_llm.profiler as profiler
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.logger import logger
from tensorrt_llm.quantization import QuantMode
@ -97,7 +97,7 @@ def TRTLLaMA(args, config):
quantization_config = pretrained_config['quantization']
build_config = config['build_config']
kv_cache_type = KVCacheType.from_string(build_config['kv_cache_type'])
kv_cache_type = KVCacheType(build_config['kv_cache_type'])
plugin_config = build_config['plugin_config']
dtype = pretrained_config['dtype']
@ -27,7 +27,7 @@ from utils import add_common_args
import tensorrt_llm
import tensorrt_llm.profiler as profiler
from tensorrt_llm import logger
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime import (PYTHON_BINDINGS, ModelConfig, ModelRunner,
SamplingConfig, Session, TensorInfo)
@ -122,8 +122,7 @@ class QWenInfer(object):
num_kv_heads = config["pretrained_config"].get("num_key_value_heads",
num_heads)
if "kv_cache_type" in config["build_config"]:
kv_cache_type = KVCacheType.from_string(
config["build_config"]["kv_cache_type"])
kv_cache_type = KVCacheType(config["build_config"]["kv_cache_type"])
else:
kv_cache_type = KVCacheType.CONTINUOUS
@ -25,7 +25,7 @@ from vit_onnx_trt import Preprocss
import tensorrt_llm
import tensorrt_llm.profiler as profiler
from tensorrt_llm import logger
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime import (ModelConfig, SamplingConfig, Session,
TensorInfo)
@ -118,8 +118,7 @@ class QWenInfer(object):
num_kv_heads = config["pretrained_config"].get("num_key_value_heads",
num_heads)
if "kv_cache_type" in config["build_config"]:
kv_cache_type = KVCacheType.from_string(
config["build_config"]["kv_cache_type"])
kv_cache_type = KVCacheType(config["build_config"]["kv_cache_type"])
else:
kv_cache_type = KVCacheType.CONTINUOUS
@ -33,7 +33,8 @@ import tensorrt_llm
import tensorrt_llm.logger as logger
from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
trt_dtype_to_torch)
from tensorrt_llm.bindings import GptJsonConfig, KVCacheType
from tensorrt_llm.bindings import GptJsonConfig
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelConfig, SamplingConfig
from tensorrt_llm.runtime.session import Session, TensorInfo
@ -9,7 +9,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
|
||||
from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, _ParallelConfig
|
||||
from ...llmapi.utils import get_type_repr
|
||||
from .models import ModelFactory, ModelFactoryRegistry
|
||||
from .utils._config import DynamicYamlMixInForSettings
|
||||
from .utils.logger import ad_logger
|
||||
@ -318,12 +317,11 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings):
|
||||
|
||||
model_config = _get_config_dict()
|
||||
|
||||
build_config: Optional[object] = Field(
|
||||
default_factory=lambda: BuildConfig(),
|
||||
build_config: Optional[BuildConfig] = Field(
|
||||
default_factory=BuildConfig,
|
||||
description="!!! DO NOT USE !!! Internal only; needed for BaseLlmArgs compatibility.",
|
||||
exclude_from_json=True,
|
||||
frozen=True,
|
||||
json_schema_extra={"type": f"Optional[{get_type_repr(BuildConfig)}]"},
|
||||
repr=False,
|
||||
)
|
||||
backend: Literal["_autodeploy"] = Field(
|
||||
|
||||
@ -22,8 +22,8 @@ TUNED_QUANTS = {
|
||||
QuantAlgo.NVFP4, QuantAlgo.FP8, QuantAlgo.FP8_BLOCK_SCALES,
|
||||
QuantAlgo.NO_QUANT, None
|
||||
}
|
||||
DEFAULT_MAX_BATCH_SIZE = BuildConfig.max_batch_size
|
||||
DEFAULT_MAX_NUM_TOKENS = BuildConfig.max_num_tokens
|
||||
DEFAULT_MAX_BATCH_SIZE = BuildConfig.model_fields["max_batch_size"].default
|
||||
DEFAULT_MAX_NUM_TOKENS = BuildConfig.model_fields["max_num_tokens"].default
|
||||
|
||||
|
||||
def get_benchmark_engine_settings(
|
||||
|
||||
@ -12,27 +12,24 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import dataclasses
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import tensorrt as trt
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ._common import _is_building, check_max_num_tokens, serialize_engine
|
||||
from ._utils import (get_sm_version, np_bfloat16, np_float8, str_dtype_to_trt,
|
||||
to_json_file, trt_gte)
|
||||
from .bindings import KVCacheType
|
||||
from .functional import PositionEmbeddingType
|
||||
from .graph_rewriting import optimize
|
||||
from .llmapi.kv_cache_type import KVCacheType
|
||||
from .logger import logger
|
||||
from .lora_helper import LoraConfig
|
||||
from .models import PretrainedConfig, PretrainedModel
|
||||
@ -46,10 +43,7 @@ from .version import __version__
class ConfigEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, KVCacheType):
# For KVCacheType, convert it to string by split of 'KVCacheType.PAGED'.
return obj.__str__().split('.')[-1]
elif hasattr(obj, 'model_dump'):
if hasattr(obj, 'model_dump'):
# Handle Pydantic models (including DecodingBaseConfig and subclasses)
return obj.model_dump(mode='json')
else:
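The encoder now duck-types on `model_dump` instead of special-casing the bindings enum. A minimal, self-contained sketch of the resulting pattern (the `Dummy` model is hypothetical, for illustration only):

import json

from pydantic import BaseModel


class ConfigEncoder(json.JSONEncoder):

    def default(self, obj):
        # Any Pydantic model (e.g. the new BuildConfig or a decoding config)
        # serializes to a plain dict in JSON mode.
        if hasattr(obj, 'model_dump'):
            return obj.model_dump(mode='json')
        return super().default(obj)


class Dummy(BaseModel):
    kv_cache_type: str = "paged"


print(json.dumps({"build_config": Dummy()}, cls=ConfigEncoder))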
@ -456,75 +450,112 @@ class Builder():
|
||||
logger.info(f'Config saved to {config_path}.')
|
||||
|
||||
|
||||
@dataclass
|
||||
class BuildConfig:
|
||||
class BuildConfig(BaseModel):
|
||||
"""Configuration class for TensorRT LLM engine building parameters.
|
||||
|
||||
This class contains all the configuration parameters needed to build a TensorRT LLM engine,
|
||||
including sequence length limits, batch sizes, optimization settings, and various features.
|
||||
|
||||
Args:
|
||||
max_input_len (int): Maximum length of input sequences. Defaults to 1024.
|
||||
max_seq_len (int, optional): The maximum possible sequence length for a single request, including both input and generated output tokens. Defaults to None.
|
||||
opt_batch_size (int): Optimal batch size for engine optimization. Defaults to 8.
|
||||
max_batch_size (int): Maximum batch size the engine can handle. Defaults to 2048.
|
||||
max_beam_width (int): Maximum beam width for beam search decoding. Defaults to 1.
|
||||
max_num_tokens (int): Maximum number of batched input tokens after padding is removed in each batch. Defaults to 8192.
|
||||
opt_num_tokens (int, optional): Optimal number of batched input tokens for engine optimization. Defaults to None.
|
||||
max_prompt_embedding_table_size (int): Maximum size of prompt embedding table for prompt tuning. Defaults to 0.
|
||||
kv_cache_type (KVCacheType, optional): Type of KV cache to use (CONTINUOUS or PAGED). If None, defaults to PAGED. Defaults to None.
|
||||
gather_context_logits (int): Whether to gather logits during context phase. Defaults to False.
|
||||
gather_generation_logits (int): Whether to gather logits during generation phase. Defaults to False.
|
||||
strongly_typed (bool): Whether to use strongly_typed. Defaults to True.
|
||||
force_num_profiles (int, optional): Force a specific number of optimization profiles. If None, auto-determined. Defaults to None.
|
||||
profiling_verbosity (str): Verbosity level for TensorRT profiling ('layer_names_only', 'detailed', 'none'). Defaults to 'layer_names_only'.
|
||||
enable_debug_output (bool): Whether to enable debug output during building. Defaults to False.
|
||||
max_draft_len (int): Maximum length of draft tokens for speculative decoding. Defaults to 0.
|
||||
speculative_decoding_mode (SpeculativeDecodingMode): Mode for speculative decoding (NONE, MEDUSA, EAGLE, etc.). Defaults to SpeculativeDecodingMode.NONE.
|
||||
use_refit (bool): Whether to enable engine refitting capabilities. Defaults to False.
|
||||
input_timing_cache (str, optional): Path to input timing cache file. If None, no input cache used. Defaults to None.
|
||||
output_timing_cache (str): Path to output timing cache file. Defaults to 'model.cache'.
|
||||
lora_config (LoraConfig): Configuration for LoRA (Low-Rank Adaptation) fine-tuning. Defaults to default LoraConfig.
|
||||
weight_sparsity (bool): Whether to enable weight sparsity optimization. Defaults to False.
|
||||
weight_streaming (bool): Whether to enable weight streaming for large models. Defaults to False.
|
||||
plugin_config (PluginConfig): Configuration for TensorRT LLM plugins. Defaults to default PluginConfig.
|
||||
use_strip_plan (bool): Whether to use stripped plan for engine building. Defaults to False.
|
||||
max_encoder_input_len (int): Maximum encoder input length for encoder-decoder models. Defaults to 1024.
|
||||
dry_run (bool): Whether to perform a dry run without actually building the engine. Defaults to False.
|
||||
visualize_network (str, optional): Path to save network visualization. If None, no visualization generated. Defaults to None.
|
||||
monitor_memory (bool): Whether to monitor memory usage during building. Defaults to False.
|
||||
use_mrope (bool): Whether to use Multi-RoPE (Rotary Position Embedding) optimization. Defaults to False.
|
||||
"""
|
||||
max_input_len: int = 1024
|
||||
max_seq_len: int = None
|
||||
opt_batch_size: int = 8
|
||||
max_batch_size: int = 2048
|
||||
max_beam_width: int = 1
|
||||
max_num_tokens: int = 8192
|
||||
opt_num_tokens: Optional[int] = None
|
||||
max_prompt_embedding_table_size: int = 0
|
||||
kv_cache_type: KVCacheType = None
|
||||
gather_context_logits: int = False
|
||||
gather_generation_logits: int = False
|
||||
strongly_typed: bool = True
|
||||
force_num_profiles: Optional[int] = None
|
||||
profiling_verbosity: str = 'layer_names_only'
|
||||
enable_debug_output: bool = False
|
||||
max_draft_len: int = 0
|
||||
speculative_decoding_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE
|
||||
use_refit: bool = False
|
||||
input_timing_cache: str = None
|
||||
output_timing_cache: str = 'model.cache'
|
||||
lora_config: LoraConfig = field(default_factory=LoraConfig)
|
||||
weight_sparsity: bool = False
|
||||
weight_streaming: bool = False
|
||||
plugin_config: PluginConfig = field(default_factory=PluginConfig)
|
||||
use_strip_plan: bool = False
|
||||
max_encoder_input_len: int = 1024 # for enc-dec DecoderModel
|
||||
dry_run: bool = False
|
||||
visualize_network: str = None
|
||||
monitor_memory: bool = False
|
||||
use_mrope: bool = False
|
||||
max_input_len: int = Field(default=1024,
|
||||
description="Maximum length of input sequences.")
|
||||
max_seq_len: Optional[int] = Field(
|
||||
default=None,
|
||||
description=
|
||||
"The maximum possible sequence length for a single request, including both input and generated "
|
||||
"output tokens.")
|
||||
opt_batch_size: int = Field(
|
||||
default=8, description="Optimal batch size for engine optimization.")
|
||||
max_batch_size: int = Field(
|
||||
default=2048, description="Maximum batch size the engine can handle.")
|
||||
max_beam_width: int = Field(
|
||||
default=1, description="Maximum beam width for beam search decoding.")
|
||||
max_num_tokens: int = Field(
|
||||
default=8192,
|
||||
description="Maximum number of batched input tokens after padding is "
|
||||
"removed in each batch.")
|
||||
opt_num_tokens: Optional[int] = Field(
|
||||
default=None,
|
||||
description=
|
||||
"Optimal number of batched input tokens for engine optimization.")
|
||||
max_prompt_embedding_table_size: int = Field(
|
||||
default=0,
|
||||
description="Maximum size of prompt embedding table for prompt tuning.")
|
||||
kv_cache_type: Optional[KVCacheType] = Field(
|
||||
default=None,
|
||||
description=
|
||||
"Type of KV cache to use (CONTINUOUS or PAGED). If None, defaults to PAGED."
|
||||
)
|
||||
gather_context_logits: bool = Field(
|
||||
default=False,
|
||||
description="Whether to gather logits during context phase.")
|
||||
gather_generation_logits: bool = Field(
|
||||
default=False,
|
||||
description="Whether to gather logits during generation phase.")
|
||||
strongly_typed: bool = Field(default=True,
|
||||
description="Whether to use strongly_typed.")
|
||||
force_num_profiles: Optional[int] = Field(
|
||||
default=None,
|
||||
description=
|
||||
"Force a specific number of optimization profiles. If None, auto-determined."
|
||||
)
|
||||
profiling_verbosity: str = Field(
|
||||
default='layer_names_only',
|
||||
description=
|
||||
"Verbosity level for TensorRT profiling ('layer_names_only', 'detailed', 'none')."
|
||||
)
|
||||
enable_debug_output: bool = Field(
|
||||
default=False,
|
||||
description="Whether to enable debug output during building.")
|
||||
max_draft_len: int = Field(
|
||||
default=0,
|
||||
description="Maximum length of draft tokens for speculative decoding.")
|
||||
speculative_decoding_mode: SpeculativeDecodingMode = Field(
|
||||
default=SpeculativeDecodingMode.NONE,
|
||||
description="Mode for speculative decoding (NONE, MEDUSA, EAGLE, etc.)."
|
||||
)
|
||||
use_refit: bool = Field(
|
||||
default=False,
|
||||
description="Whether to enable engine refitting capabilities.")
|
||||
input_timing_cache: Optional[str] = Field(
|
||||
default=None,
|
||||
description=
|
||||
"Path to input timing cache file. If None, no input cache used.")
|
||||
output_timing_cache: str = Field(
|
||||
default='model.cache', description="Path to output timing cache file.")
|
||||
lora_config: LoraConfig = Field(
|
||||
default_factory=LoraConfig,
|
||||
description="Configuration for LoRA (Low-Rank Adaptation) fine-tuning.")
|
||||
weight_sparsity: bool = Field(
|
||||
default=False,
|
||||
description="Whether to enable weight sparsity optimization.")
|
||||
weight_streaming: bool = Field(
|
||||
default=False,
|
||||
description="Whether to enable weight streaming for large models.")
|
||||
plugin_config: PluginConfig = Field(
|
||||
default_factory=PluginConfig,
|
||||
description="Configuration for TensorRT LLM plugins.")
|
||||
use_strip_plan: bool = Field(
|
||||
default=False,
|
||||
description="Whether to use stripped plan for engine building.")
|
||||
max_encoder_input_len: int = Field(
|
||||
default=1024,
|
||||
description="Maximum encoder input length for encoder-decoder models.")
|
||||
dry_run: bool = Field(
|
||||
default=False,
|
||||
description=
|
||||
"Whether to perform a dry run without actually building the engine.")
|
||||
visualize_network: Optional[str] = Field(
|
||||
default=None,
|
||||
description=
|
||||
"Path to save network visualization. If None, no visualization generated."
|
||||
)
|
||||
monitor_memory: bool = Field(
|
||||
default=False,
|
||||
description="Whether to monitor memory usage during building.")
|
||||
use_mrope: bool = Field(
|
||||
default=False,
|
||||
description=
|
||||
"Whether to use Multi-RoPE (Rotary Position Embedding) optimization.")
|
||||
|
||||
# Since we have some overlapping between kv_cache_type, paged_kv_cache, and paged_state (later two will be deprecated in the future),
|
||||
# we need to handle it given model architecture.
|
||||
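Because `BuildConfig` is now a Pydantic `BaseModel`, class-level attribute access such as `BuildConfig.max_batch_size` no longer yields the default value; the rest of this diff therefore reads defaults from the field metadata. A small sketch of the pattern, using the import path seen elsewhere in this diff:

from tensorrt_llm.builder import BuildConfig

# Old dataclass style: BuildConfig.max_batch_size was the default itself.
# New Pydantic style: defaults live in the model_fields metadata.
default_max_batch_size = BuildConfig.model_fields["max_batch_size"].default
default_max_num_tokens = BuildConfig.model_fields["max_num_tokens"].default

cfg = BuildConfig()  # fields fall back to those defaults
assert cfg.max_batch_size == default_max_batch_size
assert cfg.max_num_tokens == default_max_num_tokens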
@ -574,144 +605,10 @@ class BuildConfig:
|
||||
override_attri('paged_state', False)
|
||||
|
||||
@classmethod
|
||||
@cache
|
||||
def get_build_config_defaults(cls):
|
||||
return {
|
||||
field.name: field.default
|
||||
for field in dataclasses.fields(cls)
|
||||
if field.default is not dataclasses.MISSING
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config, plugin_config=None):
|
||||
config = copy.deepcopy(
|
||||
config
|
||||
) # it just does not make sense to change the input arg `config`
|
||||
|
||||
defaults = cls.get_build_config_defaults()
|
||||
max_input_len = config.pop('max_input_len',
|
||||
defaults.get('max_input_len'))
|
||||
max_seq_len = config.pop('max_seq_len', defaults.get('max_seq_len'))
|
||||
max_batch_size = config.pop('max_batch_size',
|
||||
defaults.get('max_batch_size'))
|
||||
max_beam_width = config.pop('max_beam_width',
|
||||
defaults.get('max_beam_width'))
|
||||
max_num_tokens = config.pop('max_num_tokens',
|
||||
defaults.get('max_num_tokens'))
|
||||
opt_num_tokens = config.pop('opt_num_tokens',
|
||||
defaults.get('opt_num_tokens'))
|
||||
opt_batch_size = config.pop('opt_batch_size',
|
||||
defaults.get('opt_batch_size'))
|
||||
max_prompt_embedding_table_size = config.pop(
|
||||
'max_prompt_embedding_table_size',
|
||||
defaults.get('max_prompt_embedding_table_size'))
|
||||
|
||||
if "kv_cache_type" in config and config["kv_cache_type"] is not None:
|
||||
kv_cache_type = KVCacheType.from_string(config.pop('kv_cache_type'))
|
||||
else:
|
||||
kv_cache_type = None
|
||||
gather_context_logits = config.pop(
|
||||
'gather_context_logits', defaults.get('gather_context_logits'))
|
||||
gather_generation_logits = config.pop(
|
||||
'gather_generation_logits',
|
||||
defaults.get('gather_generation_logits'))
|
||||
strongly_typed = config.pop('strongly_typed',
|
||||
defaults.get('strongly_typed'))
|
||||
force_num_profiles = config.pop('force_num_profiles',
|
||||
defaults.get('force_num_profiles'))
|
||||
weight_sparsity = config.pop('weight_sparsity',
|
||||
defaults.get('weight_sparsity'))
|
||||
profiling_verbosity = config.pop('profiling_verbosity',
|
||||
defaults.get('profiling_verbosity'))
|
||||
enable_debug_output = config.pop('enable_debug_output',
|
||||
defaults.get('enable_debug_output'))
|
||||
max_draft_len = config.pop('max_draft_len',
|
||||
defaults.get('max_draft_len'))
|
||||
speculative_decoding_mode = config.pop(
|
||||
'speculative_decoding_mode',
|
||||
defaults.get('speculative_decoding_mode'))
|
||||
use_refit = config.pop('use_refit', defaults.get('use_refit'))
|
||||
input_timing_cache = config.pop('input_timing_cache',
|
||||
defaults.get('input_timing_cache'))
|
||||
output_timing_cache = config.pop('output_timing_cache',
|
||||
defaults.get('output_timing_cache'))
|
||||
lora_config = LoraConfig(**config.get('lora_config', {}))
|
||||
max_encoder_input_len = config.pop(
|
||||
'max_encoder_input_len', defaults.get('max_encoder_input_len'))
|
||||
weight_streaming = config.pop('weight_streaming',
|
||||
defaults.get('weight_streaming'))
|
||||
use_strip_plan = config.pop('use_strip_plan',
|
||||
defaults.get('use_strip_plan'))
|
||||
|
||||
if plugin_config is None:
|
||||
plugin_config = PluginConfig()
|
||||
if "plugin_config" in config.keys():
|
||||
plugin_config = plugin_config.model_copy(
|
||||
update=config["plugin_config"], deep=True)
|
||||
|
||||
dry_run = config.pop('dry_run', defaults.get('dry_run'))
|
||||
visualize_network = config.pop('visualize_network',
|
||||
defaults.get('visualize_network'))
|
||||
monitor_memory = config.pop('monitor_memory',
|
||||
defaults.get('monitor_memory'))
|
||||
use_mrope = config.pop('use_mrope', defaults.get('use_mrope'))
|
||||
|
||||
return cls(
|
||||
max_input_len=max_input_len,
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
max_beam_width=max_beam_width,
|
||||
max_num_tokens=max_num_tokens,
|
||||
opt_num_tokens=opt_num_tokens,
|
||||
opt_batch_size=opt_batch_size,
|
||||
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
|
||||
kv_cache_type=kv_cache_type,
|
||||
gather_context_logits=gather_context_logits,
|
||||
gather_generation_logits=gather_generation_logits,
|
||||
strongly_typed=strongly_typed,
|
||||
force_num_profiles=force_num_profiles,
|
||||
profiling_verbosity=profiling_verbosity,
|
||||
enable_debug_output=enable_debug_output,
|
||||
max_draft_len=max_draft_len,
|
||||
speculative_decoding_mode=speculative_decoding_mode,
|
||||
use_refit=use_refit,
|
||||
input_timing_cache=input_timing_cache,
|
||||
output_timing_cache=output_timing_cache,
|
||||
lora_config=lora_config,
|
||||
use_strip_plan=use_strip_plan,
|
||||
max_encoder_input_len=max_encoder_input_len,
|
||||
weight_sparsity=weight_sparsity,
|
||||
weight_streaming=weight_streaming,
|
||||
plugin_config=plugin_config,
|
||||
dry_run=dry_run,
|
||||
visualize_network=visualize_network,
|
||||
monitor_memory=monitor_memory,
|
||||
use_mrope=use_mrope)
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, config_file, plugin_config=None):
|
||||
def from_json_file(cls, config_file):
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
return BuildConfig.from_dict(config, plugin_config=plugin_config)
|
||||
|
||||
def to_dict(self):
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
# the enum KVCacheType cannot be converted automatically
|
||||
if output.get('kv_cache_type', None) is not None:
|
||||
output['kv_cache_type'] = str(output['kv_cache_type'].name)
|
||||
output['plugin_config'] = output['plugin_config'].model_dump()
|
||||
output['lora_config'] = output['lora_config'].model_dump()
|
||||
return output
|
||||
|
||||
def update_from_dict(self, config: dict):
|
||||
for name, value in config.items():
|
||||
if not hasattr(self, name):
|
||||
raise AttributeError(
|
||||
f"{self.__class__} object has no attribute {name}")
|
||||
setattr(self, name, value)
|
||||
|
||||
def update(self, **kwargs):
|
||||
self.update_from_dict(kwargs)
|
||||
return BuildConfig(**config)
|
||||
|
||||
|
||||
class EngineConfig:
|
||||
@ -731,11 +628,10 @@ class EngineConfig:
def from_json_str(cls, config_str):
config = json.loads(config_str)
return cls(PretrainedConfig.from_dict(config['pretrained_config']),
BuildConfig.from_dict(config['build_config']),
config['version'])
BuildConfig(**config['build_config']), config['version'])
def to_dict(self):
build_config = self.build_config.to_dict()
build_config = self.build_config.model_dump(mode="json")
build_config.pop('dry_run', None) # Not an Engine Characteristic
build_config.pop('visualize_network',
None) # Not an Engine Characteristic
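With `to_dict`/`from_dict` gone, serialization becomes a plain Pydantic round trip, which the tests near the end of this diff exercise. A hedged sketch, assuming nested plugin/LoRA configs validate from plain dicts as those tests do:

from tensorrt_llm.builder import BuildConfig

cfg = BuildConfig(max_batch_size=256, max_num_tokens=4096)
as_dict = cfg.model_dump(mode="json")   # replaces BuildConfig.to_dict()
restored = BuildConfig(**as_dict)       # replaces BuildConfig.from_dict()
assert restored.max_batch_size == 256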
@ -1081,7 +977,7 @@ def build(model: PretrainedModel, build_config: BuildConfig) -> Engine:
'''
tic = time.time()
# avoid changing the input config
build_config = copy.deepcopy(build_config)
build_config = build_config.model_copy(deep=True)
build_config.plugin_config.dtype = model.config.dtype
build_config.update_kv_cache_type(model.config.architecture)
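`model_copy(deep=True)` is the Pydantic replacement for `copy.deepcopy` used above and in the refit and serialization paths later in this diff; a minimal sketch:

from tensorrt_llm.builder import BuildConfig

original = BuildConfig(max_input_len=2048)
copied = original.model_copy(deep=True)  # deep copy, nested configs included
copied.max_input_len = 4096
assert original.max_input_len == 2048    # the source object is untouched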
@ -26,8 +26,8 @@ import torch
|
||||
|
||||
from tensorrt_llm._utils import (local_mpi_rank, local_mpi_size, mpi_barrier,
|
||||
mpi_comm, mpi_rank, mpi_world_size)
|
||||
from tensorrt_llm.bindings import KVCacheType
|
||||
from tensorrt_llm.builder import BuildConfig, Engine, build
|
||||
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
|
||||
from tensorrt_llm.logger import logger, severity_map
|
||||
from tensorrt_llm.lora_helper import LoraConfig
|
||||
from tensorrt_llm.lora_manager import LoraManager
|
||||
@ -37,23 +37,6 @@ from tensorrt_llm.plugin import PluginConfig, add_plugin_argument
|
||||
from tensorrt_llm.quantization.mode import QuantAlgo
|
||||
|
||||
|
||||
def enum_type(enum_class):
|
||||
|
||||
def parse_enum(value):
|
||||
if isinstance(value, enum_class):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
return enum_class.from_string(value)
|
||||
|
||||
valid_values = [e.name for e in enum_class]
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Invalid value '{value}' of type {type(value).__name__}. Expected one of {valid_values}"
|
||||
)
|
||||
|
||||
return parse_enum
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
@ -92,29 +75,30 @@ def parse_arguments():
|
||||
parser.add_argument(
|
||||
'--max_batch_size',
|
||||
type=int,
|
||||
default=BuildConfig.max_batch_size,
|
||||
default=BuildConfig.model_fields["max_batch_size"].default,
|
||||
help="Maximum number of requests that the engine can schedule.")
|
||||
parser.add_argument('--max_input_len',
|
||||
parser.add_argument(
|
||||
'--max_input_len',
|
||||
type=int,
|
||||
default=BuildConfig.max_input_len,
|
||||
default=BuildConfig.model_fields["max_input_len"].default,
|
||||
help="Maximum input length of one request.")
|
||||
parser.add_argument(
|
||||
'--max_seq_len',
|
||||
'--max_decoder_seq_len',
|
||||
dest='max_seq_len',
|
||||
type=int,
|
||||
default=BuildConfig.max_seq_len,
|
||||
default=BuildConfig.model_fields["max_seq_len"].default,
|
||||
help="Maximum total length of one request, including prompt and outputs. "
|
||||
"If unspecified, the value is deduced from the model config.")
|
||||
parser.add_argument(
|
||||
'--max_beam_width',
|
||||
type=int,
|
||||
default=BuildConfig.max_beam_width,
|
||||
default=BuildConfig.model_fields["max_beam_width"].default,
|
||||
help="Maximum number of beams for beam search decoding.")
|
||||
parser.add_argument(
|
||||
'--max_num_tokens',
|
||||
type=int,
|
||||
default=BuildConfig.max_num_tokens,
|
||||
default=BuildConfig.model_fields["max_num_tokens"].default,
|
||||
help=
|
||||
"Maximum number of batched input tokens after padding is removed in each batch. "
|
||||
"Currently, the input padding is removed by default; "
|
||||
@ -123,7 +107,7 @@ def parse_arguments():
|
||||
parser.add_argument(
|
||||
'--opt_num_tokens',
|
||||
type=int,
|
||||
default=BuildConfig.opt_num_tokens,
|
||||
default=BuildConfig.model_fields["opt_num_tokens"].default,
|
||||
help=
|
||||
"Optimal number of batched input tokens after padding is removed in each batch "
|
||||
"It equals to ``max_batch_size * max_beam_width`` by default, set this "
|
||||
@ -132,7 +116,7 @@ def parse_arguments():
|
||||
parser.add_argument(
|
||||
'--max_encoder_input_len',
|
||||
type=int,
|
||||
default=BuildConfig.max_encoder_input_len,
|
||||
default=BuildConfig.model_fields["max_encoder_input_len"].default,
|
||||
help="Maximum encoder input length for enc-dec models. "
|
||||
"Set ``max_input_len`` to 1 to start generation from decoder_start_token_id of length 1."
|
||||
)
|
||||
@ -140,14 +124,15 @@ def parse_arguments():
|
||||
'--max_prompt_embedding_table_size',
|
||||
'--max_multimodal_len',
|
||||
type=int,
|
||||
default=BuildConfig.max_prompt_embedding_table_size,
|
||||
default=BuildConfig.model_fields["max_prompt_embedding_table_size"].
|
||||
default,
|
||||
help=
|
||||
"Maximum prompt embedding table size for prompt tuning, or maximum multimodal input size for multimodal models. "
|
||||
"Setting a value > 0 enables prompt tuning or multimodal input.")
|
||||
parser.add_argument(
'--kv_cache_type',
default=argparse.SUPPRESS,
type=enum_type(KVCacheType),
type=KVCacheType,
help=
"Set KV cache type (continuous, paged, or disabled). For disabled case, KV cache is disabled and only context phase is allowed."
)
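Because the new `KVCacheType` is a `str`-based enum whose `_missing_` hook accepts case-insensitive strings, the class itself can serve as the argparse `type=` callable, which is what makes the removed `enum_type()` helper unnecessary. A sketch of the behavior:

import argparse

from tensorrt_llm.llmapi.kv_cache_type import KVCacheType

parser = argparse.ArgumentParser()
parser.add_argument('--kv_cache_type', type=KVCacheType, default=argparse.SUPPRESS)

args = parser.parse_args(['--kv_cache_type', 'PAGED'])
assert args.kv_cache_type is KVCacheType.PAGED  # case-insensitive match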
@ -156,42 +141,44 @@ def parse_arguments():
|
||||
type=str,
|
||||
default=argparse.SUPPRESS,
|
||||
help=
|
||||
"Deprecated. Enabling this option is equvilient to ``--kv_cache_type paged`` for transformer based models."
|
||||
"Deprecated. Enabling this option is equivalent to ``--kv_cache_type paged`` for transformer based models."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--input_timing_cache',
|
||||
type=str,
|
||||
default=BuildConfig.input_timing_cache,
|
||||
default=BuildConfig.model_fields["input_timing_cache"].default,
|
||||
help=
|
||||
"The file path to read the timing cache. This option is ignored if the file does not exist."
|
||||
)
|
||||
parser.add_argument('--output_timing_cache',
|
||||
parser.add_argument(
|
||||
'--output_timing_cache',
|
||||
type=str,
|
||||
default=BuildConfig.output_timing_cache,
|
||||
default=BuildConfig.model_fields["output_timing_cache"].default,
|
||||
help="The file path to write the timing cache.")
|
||||
parser.add_argument(
|
||||
'--profiling_verbosity',
|
||||
type=str,
|
||||
default=BuildConfig.profiling_verbosity,
|
||||
default=BuildConfig.model_fields["profiling_verbosity"].default,
|
||||
choices=['layer_names_only', 'detailed', 'none'],
|
||||
help=
|
||||
"The profiling verbosity for the generated TensorRT engine. Setting to detailed allows inspecting tactic choices and kernel parameters."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--strip_plan',
|
||||
default=BuildConfig.use_strip_plan,
|
||||
default=BuildConfig.model_fields["use_strip_plan"].default,
|
||||
action='store_true',
|
||||
help=
|
||||
"Enable stripping weights from the final TensorRT engine under the assumption that the refit weights are identical to those provided at build time."
|
||||
)
|
||||
parser.add_argument('--weight_sparsity',
|
||||
default=BuildConfig.weight_sparsity,
|
||||
parser.add_argument(
|
||||
'--weight_sparsity',
|
||||
default=BuildConfig.model_fields["weight_sparsity"].default,
|
||||
action='store_true',
|
||||
help="Enable weight sparsity.")
|
||||
parser.add_argument(
|
||||
'--weight_streaming',
|
||||
default=BuildConfig.weight_streaming,
|
||||
default=BuildConfig.model_fields["weight_streaming"].default,
|
||||
action='store_true',
|
||||
help=
|
||||
"Enable offloading weights to CPU and streaming loading at runtime.",
|
||||
@ -213,8 +200,9 @@ def parse_arguments():
|
||||
default='info',
|
||||
choices=severity_map.keys(),
|
||||
help="The logging level.")
|
||||
parser.add_argument('--enable_debug_output',
|
||||
default=BuildConfig.enable_debug_output,
|
||||
parser.add_argument(
|
||||
'--enable_debug_output',
|
||||
default=BuildConfig.model_fields["enable_debug_output"].default,
|
||||
action='store_true',
|
||||
help="Enable debug output.")
|
||||
parser.add_argument(
|
||||
@ -226,7 +214,7 @@ def parse_arguments():
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dry_run',
|
||||
default=BuildConfig.dry_run,
|
||||
default=BuildConfig.model_fields["dry_run"].default,
|
||||
action='store_true',
|
||||
help=
|
||||
"Run through the build process except the actual Engine build for debugging."
|
||||
@ -519,65 +507,37 @@ def main():
|
||||
f"Overriding # of builder profiles <= {force_num_profiles_from_env}."
|
||||
)
|
||||
|
||||
build_config = BuildConfig.from_dict(
|
||||
{
|
||||
'max_input_len':
|
||||
args.max_input_len,
|
||||
'max_seq_len':
|
||||
args.max_seq_len,
|
||||
'max_batch_size':
|
||||
args.max_batch_size,
|
||||
'max_beam_width':
|
||||
args.max_beam_width,
|
||||
'max_num_tokens':
|
||||
args.max_num_tokens,
|
||||
'opt_num_tokens':
|
||||
args.opt_num_tokens,
|
||||
'max_prompt_embedding_table_size':
|
||||
args.max_prompt_embedding_table_size,
|
||||
'gather_context_logits':
|
||||
args.gather_context_logits,
|
||||
'gather_generation_logits':
|
||||
args.gather_generation_logits,
|
||||
'strongly_typed':
|
||||
True,
|
||||
'force_num_profiles':
|
||||
force_num_profiles_from_env,
|
||||
'weight_sparsity':
|
||||
args.weight_sparsity,
|
||||
'profiling_verbosity':
|
||||
args.profiling_verbosity,
|
||||
'enable_debug_output':
|
||||
args.enable_debug_output,
|
||||
'max_draft_len':
|
||||
args.max_draft_len,
|
||||
'speculative_decoding_mode':
|
||||
speculative_decoding_mode,
|
||||
'input_timing_cache':
|
||||
args.input_timing_cache,
|
||||
'output_timing_cache':
|
||||
args.output_timing_cache,
|
||||
'dry_run':
|
||||
args.dry_run,
|
||||
'visualize_network':
|
||||
args.visualize_network,
|
||||
'max_encoder_input_len':
|
||||
args.max_encoder_input_len,
|
||||
'weight_streaming':
|
||||
args.weight_streaming,
|
||||
'monitor_memory':
|
||||
args.monitor_memory,
|
||||
'use_mrope':
|
||||
(True if model_config.qwen_type == "qwen2_vl" else False)
|
||||
if hasattr(model_config, "qwen_type") else False
|
||||
},
|
||||
build_config = BuildConfig(
|
||||
max_input_len=args.max_input_len,
|
||||
max_seq_len=args.max_seq_len,
|
||||
max_batch_size=args.max_batch_size,
|
||||
max_beam_width=args.max_beam_width,
|
||||
max_num_tokens=args.max_num_tokens,
|
||||
opt_num_tokens=args.opt_num_tokens,
|
||||
max_prompt_embedding_table_size=args.
|
||||
max_prompt_embedding_table_size,
|
||||
kv_cache_type=getattr(args, "kv_cache_type", None),
|
||||
gather_context_logits=args.gather_context_logits,
|
||||
gather_generation_logits=args.gather_generation_logits,
|
||||
strongly_typed=True,
|
||||
force_num_profiles=force_num_profiles_from_env,
|
||||
weight_sparsity=args.weight_sparsity,
|
||||
profiling_verbosity=args.profiling_verbosity,
|
||||
enable_debug_output=args.enable_debug_output,
|
||||
max_draft_len=args.max_draft_len,
|
||||
speculative_decoding_mode=speculative_decoding_mode,
|
||||
input_timing_cache=args.input_timing_cache,
|
||||
output_timing_cache=args.output_timing_cache,
|
||||
dry_run=args.dry_run,
|
||||
visualize_network=args.visualize_network,
|
||||
max_encoder_input_len=args.max_encoder_input_len,
|
||||
weight_streaming=args.weight_streaming,
|
||||
monitor_memory=args.monitor_memory,
|
||||
use_mrope=getattr(model_config, "qwen_type", None) == "qwen2_vl",
|
||||
plugin_config=plugin_config)
|
||||
|
||||
if hasattr(args, 'kv_cache_type'):
|
||||
build_config.update_from_dict({'kv_cache_type': args.kv_cache_type})
|
||||
else:
|
||||
build_config = BuildConfig.from_json_file(args.build_config,
|
||||
plugin_config=plugin_config)
|
||||
build_config = BuildConfig.from_json_file(args.build_config)
|
||||
build_config.plugin_config = plugin_config
|
||||
|
||||
parallel_build(model_config, ckpt_dir, build_config, args.output_dir,
|
||||
workers, args.log_level, model_cls, **kwargs)
|
||||
|
||||
@ -50,23 +50,23 @@ from ..logger import logger, severity_map
|
||||
help="The logging level.")
|
||||
@click.option("--max_beam_width",
|
||||
type=int,
|
||||
default=BuildConfig.max_beam_width,
|
||||
default=BuildConfig.model_fields["max_beam_width"].default,
|
||||
help="Maximum number of beams for beam search decoding.")
|
||||
@click.option("--max_batch_size",
|
||||
type=int,
|
||||
default=BuildConfig.max_batch_size,
|
||||
default=BuildConfig.model_fields["max_batch_size"].default,
|
||||
help="Maximum number of requests that the engine can schedule.")
|
||||
@click.option(
|
||||
"--max_num_tokens",
|
||||
type=int,
|
||||
default=BuildConfig.max_num_tokens,
|
||||
default=BuildConfig.model_fields["max_num_tokens"].default,
|
||||
help=
|
||||
"Maximum number of batched input tokens after padding is removed in each batch."
|
||||
)
|
||||
@click.option(
|
||||
"--max_seq_len",
|
||||
type=int,
|
||||
default=BuildConfig.max_seq_len,
|
||||
default=BuildConfig.model_fields["max_seq_len"].default,
|
||||
help="Maximum total length of one request, including prompt and outputs. "
|
||||
"If unspecified, the value is deduced from the model config.")
|
||||
@click.option("--tp_size", type=int, default=1, help='Tensor parallelism size.')
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
Script that refits TRT-LLM engine(s) with weights in a TRT-LLM checkpoint.
|
||||
'''
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@ -57,7 +56,7 @@ def refit_engine(engine_path: str, refit_engine_dir: str, checkpoint_dir: str,
|
||||
|
||||
# There are weights preprocess during optimize model.
|
||||
tik = time.time()
|
||||
build_config = copy.deepcopy(engine_config.build_config)
|
||||
build_config = engine_config.build_config.model_copy(deep=True)
|
||||
optimize_model_with_config(model, build_config)
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
|
||||
@ -75,13 +75,17 @@ def _signal_handler_cleanup_child(signum, frame):
|
||||
sys.exit(128 + signum)
|
||||
|
||||
|
||||
def get_llm_args(model: str,
|
||||
def get_llm_args(
|
||||
model: str,
|
||||
tokenizer: Optional[str] = None,
|
||||
backend: str = "pytorch",
|
||||
max_beam_width: int = BuildConfig.max_beam_width,
|
||||
max_batch_size: int = BuildConfig.max_batch_size,
|
||||
max_num_tokens: int = BuildConfig.max_num_tokens,
|
||||
max_seq_len: int = BuildConfig.max_seq_len,
|
||||
max_beam_width: int = BuildConfig.model_fields["max_beam_width"].
|
||||
default,
|
||||
max_batch_size: int = BuildConfig.model_fields["max_batch_size"].
|
||||
default,
|
||||
max_num_tokens: int = BuildConfig.model_fields["max_num_tokens"].
|
||||
default,
|
||||
max_seq_len: int = BuildConfig.model_fields["max_seq_len"].default,
|
||||
tensor_parallel_size: int = 1,
|
||||
pipeline_parallel_size: int = 1,
|
||||
moe_expert_parallel_size: Optional[int] = None,
|
||||
@ -242,23 +246,23 @@ class ChoiceWithAlias(click.Choice):
|
||||
help="The logging level.")
|
||||
@click.option("--max_beam_width",
|
||||
type=int,
|
||||
default=BuildConfig.max_beam_width,
|
||||
default=BuildConfig.model_fields["max_beam_width"].default,
|
||||
help="Maximum number of beams for beam search decoding.")
|
||||
@click.option("--max_batch_size",
|
||||
type=int,
|
||||
default=BuildConfig.max_batch_size,
|
||||
default=BuildConfig.model_fields["max_batch_size"].default,
|
||||
help="Maximum number of requests that the engine can schedule.")
|
||||
@click.option(
|
||||
"--max_num_tokens",
|
||||
type=int,
|
||||
default=BuildConfig.max_num_tokens,
|
||||
default=BuildConfig.model_fields["max_num_tokens"].default,
|
||||
help=
|
||||
"Maximum number of batched input tokens after padding is removed in each batch."
|
||||
)
|
||||
@click.option(
|
||||
"--max_seq_len",
|
||||
type=int,
|
||||
default=BuildConfig.max_seq_len,
|
||||
default=BuildConfig.model_fields["max_seq_len"].default,
|
||||
help="Maximum total length of one request, including prompt and outputs. "
|
||||
"If unspecified, the value is deduced from the model config.")
|
||||
@click.option("--tp_size", type=int, default=1, help='Tensor parallelism size.')
|
||||
@ -436,7 +440,7 @@ def serve(
|
||||
help="The logging level.")
|
||||
@click.option("--max_batch_size",
|
||||
type=int,
|
||||
default=BuildConfig.max_batch_size,
|
||||
default=BuildConfig.model_fields["max_batch_size"].default,
|
||||
help="Maximum number of requests that the engine can schedule.")
|
||||
@click.option(
|
||||
"--max_num_tokens",
|
||||
|
||||
@ -104,7 +104,7 @@ class BuildCache:
|
||||
Get the build step for engine building.
|
||||
'''
|
||||
build_config_str = json.dumps(self.prune_build_config_for_cache_key(
|
||||
build_config.to_dict()),
|
||||
build_config.model_dump(mode="json")),
|
||||
sort_keys=True)
|
||||
|
||||
kwargs_str = json.dumps(kwargs, sort_keys=True)
|
||||
|
||||
tensorrt_llm/llmapi/kv_cache_type.py (Normal file, 50 lines)
@ -0,0 +1,50 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import tensorrt_llm.bindings as _bindings


class KVCacheType(str, Enum):
    """Python enum wrapper for KVCacheType.

    This is a pure Python enum that mirrors the C++ KVCacheType enum exposed
    through pybind11.
    """
    CONTINUOUS = "continuous"
    PAGED = "paged"
    DISABLED = "disabled"

    @classmethod
    def _missing_(cls, value):
        """Allow case-insensitive string values to be converted to enum members."""
        if isinstance(value, str):
            for member in cls:
                if member.value.lower() == value.lower():
                    return member
        return None

    def to_cpp(self) -> '_bindings.KVCacheType':
        import tensorrt_llm.bindings as _bindings
        return getattr(_bindings.KVCacheType, self.name)

    @classmethod
    def from_cpp(cls, cpp_enum) -> 'KVCacheType':
        # C++ enum's __str__ returns "KVCacheType.PAGED", extract the name
        name = str(cpp_enum).split('.')[-1]
        return cls(name)
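A short usage sketch of the wrapper defined above, mirroring the unit test added near the end of this diff (`to_cpp()` needs the compiled bindings importable at runtime):

from tensorrt_llm.llmapi.kv_cache_type import KVCacheType

kv = KVCacheType("Paged")      # case-insensitive thanks to _missing_
assert kv is KVCacheType.PAGED
assert kv.value == "paged"     # plain str value, JSON/YAML friendly

# Round trip through the C++ enum (requires tensorrt_llm.bindings).
cpp_enum = kv.to_cpp()
assert KVCacheType.from_cpp(cpp_enum) is KVCacheType.PAGED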
@ -1,5 +1,4 @@
|
||||
import ast
|
||||
import copy
|
||||
import functools
|
||||
import json
|
||||
import math
|
||||
@ -1764,17 +1763,6 @@ class BaseLlmArgs(StrictBaseModel):
|
||||
ret = cls(**kwargs)
|
||||
return ret
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Dump `LlmArgs` instance to a dict.
|
||||
|
||||
Returns:
|
||||
dict: The dict that contains all fields of the `LlmArgs` instance.
|
||||
"""
|
||||
model_dict = self.model_dump(mode='json')
|
||||
# TODO: the BuildConfig.to_dict and from_dict don't work well with pydantic
|
||||
model_dict['build_config'] = copy.deepcopy(self.build_config)
|
||||
return model_dict
|
||||
|
||||
@staticmethod
|
||||
def _check_consistency(kwargs_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# max_beam_width is not included since vague behavior due to lacking the support for dynamic beam width during
|
||||
@ -1919,10 +1907,6 @@ class BaseLlmArgs(StrictBaseModel):
|
||||
if self.max_input_len:
|
||||
kwargs["max_input_len"] = self.max_input_len
|
||||
self.build_config = BuildConfig(**kwargs)
|
||||
else:
|
||||
assert isinstance(
|
||||
build_config,
|
||||
BuildConfig), f"build_config is not initialized: {build_config}"
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
@ -2001,7 +1985,7 @@ class BaseLlmArgs(StrictBaseModel):
|
||||
# TODO: remove the checker when manage weights support all data types
|
||||
if is_trt_llm_args and self.fast_build and (self.quant_config.quant_algo
|
||||
is QuantAlgo.FP8):
|
||||
self._update_plugin_config("manage_weights", True)
|
||||
self.build_config.plugin_config.manage_weights = True
|
||||
|
||||
if self.parallel_config.world_size == 1 and self.build_config:
|
||||
self.build_config.plugin_config.nccl_plugin = None
|
||||
@ -2166,9 +2150,6 @@ class BaseLlmArgs(StrictBaseModel):
|
||||
"while LoRA prefetch is not supported")
|
||||
return self
|
||||
|
||||
def _update_plugin_config(self, key: str, value: Any):
|
||||
setattr(self.build_config.plugin_config, key, value)
|
||||
|
||||
def _load_config_from_engine(self, engine_dir: Path):
|
||||
engine_config = EngineConfig.from_json_file(engine_dir / "config.json")
|
||||
self._pretrained_config = engine_config.pretrained_config
|
||||
@ -2271,10 +2252,8 @@ class TrtLlmArgs(BaseLlmArgs):
|
||||
fast_build: bool = Field(default=False, description="Enable fast build.")
|
||||
|
||||
# BuildConfig is introduced to give users a familiar interface to configure the model building.
|
||||
build_config: Optional[object] = Field(
|
||||
default=None,
|
||||
description="Build config.",
|
||||
json_schema_extra={"type": f"Optional[{get_type_repr(BuildConfig)}]"})
|
||||
build_config: Optional[BuildConfig] = Field(default=None,
|
||||
description="Build config.")
|
||||
|
||||
# Prompt adapter arguments
|
||||
enable_prompt_adapter: bool = Field(default=False,
|
||||
@ -2405,11 +2384,10 @@ class TorchCompileConfig(StrictBaseModel):
|
||||
|
||||
class TorchLlmArgs(BaseLlmArgs):
|
||||
# Just a dummy BuildConfig to allow code reuse with the TrtLlmArgs
|
||||
build_config: Optional[object] = Field(
|
||||
build_config: Optional[BuildConfig] = Field(
|
||||
default=None,
|
||||
description="Build config.",
|
||||
exclude_from_json=True,
|
||||
json_schema_extra={"type": f"Optional[{get_type_repr(BuildConfig)}]"},
|
||||
status="deprecated",
|
||||
)
|
||||
|
||||
@ -2911,10 +2889,7 @@ def update_llm_args_with_extra_dict(
|
||||
for field_name, field_type in field_mapping.items():
|
||||
if field_name in llm_args_dict:
|
||||
# Some fields need to be converted manually.
|
||||
if field_name in [
|
||||
"speculative_config", "build_config",
|
||||
"sparse_attention_config"
|
||||
]:
|
||||
if field_name in ["speculative_config", "sparse_attention_config"]:
|
||||
llm_args_dict[field_name] = field_type.from_dict(
|
||||
llm_args_dict[field_name])
|
||||
else:
|
||||
@ -2928,6 +2903,10 @@ def update_llm_args_with_extra_dict(
|
||||
# For trtllm-bench or trtllm-serve, build_config may be passed for the PyTorch
|
||||
# backend, overwriting the knobs there since build_config always has the highest priority
|
||||
if "build_config" in llm_args:
|
||||
# Ensure build_config is a BuildConfig object, not a dict
|
||||
if isinstance(llm_args["build_config"], dict):
|
||||
llm_args["build_config"] = BuildConfig(**llm_args["build_config"])
|
||||
|
||||
for key in [
|
||||
"max_batch_size",
|
||||
"max_num_tokens",
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
@ -530,8 +529,8 @@ class ModelLoader:
|
||||
|
||||
logger_debug(f"rank{mpi_rank()} begin to build engine...\n", "green")
|
||||
|
||||
# avoid the original build_config is modified, avoid the side effect
|
||||
copied_build_config = copy.deepcopy(self.build_config)
|
||||
# avoid side effects by copying the original build_config
|
||||
copied_build_config = self.build_config.model_copy(deep=True)
|
||||
|
||||
copied_build_config.update_kv_cache_type(self._model_info.architecture)
|
||||
assert self.model is not None, "model is loaded yet."
|
||||
|
||||
@ -27,10 +27,10 @@ from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader
|
||||
|
||||
from ..._common import default_net, default_trtnet
|
||||
from ..._utils import pad_vocab_size
|
||||
from ...bindings import KVCacheType
|
||||
from ...functional import (Tensor, _create_tensor, cast, concat,
|
||||
gather_last_token_logits, index_select, shape)
|
||||
from ...layers import AttentionParams, ColumnLinear, SpecDecodingParams
|
||||
from ...llmapi.kv_cache_type import KVCacheType
|
||||
from ...module import Module, ModuleList
|
||||
from ...plugin import TRT_LLM_PLUGIN_NAMESPACE
|
||||
from ..modeling_utils import QuantConfig
|
||||
|
||||
@ -18,9 +18,9 @@ from typing import List, Optional
|
||||
|
||||
import tensorrt as trt
|
||||
|
||||
from ..bindings import KVCacheType
|
||||
from ..functional import Tensor
|
||||
from ..layers import MropeParams, SpecDecodingParams
|
||||
from ..llmapi.kv_cache_type import KVCacheType
|
||||
from ..mapping import Mapping
|
||||
from ..plugin import current_all_reduce_helper
|
||||
|
||||
|
||||
@ -21,7 +21,6 @@ import torch
|
||||
|
||||
from tensorrt_llm._common import default_net
|
||||
from tensorrt_llm._utils import numpy_to_torch, str_dtype_to_torch
|
||||
from tensorrt_llm.bindings import KVCacheType
|
||||
from tensorrt_llm.functional import (Conditional, LayerNormPositionType,
|
||||
LayerNormType, MLPType,
|
||||
PositionEmbeddingType, Tensor, assertion,
|
||||
@ -32,6 +31,7 @@ from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams,
|
||||
ColumnLinear, Embedding, FusedGatedMLP,
|
||||
GatedMLP, GroupNorm, KeyValueCacheParams,
|
||||
LayerNorm, LoraParams, RmsNorm)
|
||||
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
|
||||
from tensorrt_llm.lora_helper import (LoraConfig,
|
||||
get_default_trtllm_modules_to_hf_modules,
|
||||
use_lora)
|
||||
|
||||
@ -19,7 +19,6 @@ from .._common import default_net
|
||||
from .._utils import (QuantModeWrapper, get_init_params, numpy_to_torch,
|
||||
release_gc, str_dtype_to_torch, str_dtype_to_trt,
|
||||
trt_dtype_to_torch)
|
||||
from ..bindings import KVCacheType
|
||||
from ..bindings.executor import RuntimeDefaults
|
||||
from ..functional import (PositionEmbeddingType, Tensor, allgather, constant,
|
||||
cp_split_plugin, gather_last_token_logits,
|
||||
@ -31,6 +30,7 @@ from ..layers.attention import Attention, BertAttention
|
||||
from ..layers.linear import ColumnLinear, Linear, RowLinear
|
||||
from ..layers.lora import Dora, Lora
|
||||
from ..layers.moe import MOE, MoeOOTB
|
||||
from ..llmapi.kv_cache_type import KVCacheType
|
||||
from ..logger import logger
|
||||
from ..mapping import Mapping
|
||||
from ..module import Module, ModuleList
|
||||
|
||||
@ -15,7 +15,6 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Type, Union
|
||||
|
||||
from tensorrt_llm.bindings import KVCacheType
|
||||
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
|
||||
AttentionMaskType, PositionEmbeddingType,
|
||||
Tensor, gather_last_token_logits, recv,
|
||||
@ -28,6 +27,7 @@ from tensorrt_llm.layers.linear import ColumnLinear
|
||||
from tensorrt_llm.layers.lora import LoraParams
|
||||
from tensorrt_llm.layers.mlp import GatedMLP
|
||||
from tensorrt_llm.layers.normalization import RmsNorm
|
||||
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.models.convert_utils import has_safetensors
|
||||
from tensorrt_llm.models.modeling_utils import DecoderModelForCausalLM
|
||||
|
||||
@ -18,8 +18,8 @@ from collections import OrderedDict
|
||||
import tensorrt as trt
|
||||
|
||||
from tensorrt_llm._common import default_net
|
||||
from tensorrt_llm.bindings import KVCacheType
|
||||
from tensorrt_llm.functional import Tensor, cast, categorical_sample
|
||||
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
|
||||
from tensorrt_llm.models import LLaMAForCausalLM, QWenForCausalLM
|
||||
from tensorrt_llm.models.generation_mixin import GenerationMixin
|
||||
|
||||
|
||||
@ -31,7 +31,6 @@ from tqdm import tqdm
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._common import default_net
|
||||
from tensorrt_llm._utils import str_dtype_to_trt, trt_dtype_to_str
|
||||
from tensorrt_llm.bindings import KVCacheType
|
||||
from tensorrt_llm.functional import (ACT2FN, AttentionMaskType, LayerNormType,
|
||||
PositionEmbeddingType, Tensor,
|
||||
constant_to_tensor_)
|
||||
@ -41,6 +40,7 @@ from tensorrt_llm.layers.attention import (Attention, AttentionParams,
|
||||
BertAttention, KeyValueCacheParams,
|
||||
bert_attention, layernorm_map)
|
||||
from tensorrt_llm.layers.normalization import RmsNorm
|
||||
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.models.generation_mixin import GenerationMixin
|
||||
from tensorrt_llm.models.model_weights_loader import (ModelWeightsFormat,
|
||||
|
||||
@ -43,8 +43,9 @@ from tensorrt_llm.runtime.redrafter_utils import *
|
||||
from .._utils import (binding_layer_type_to_str, binding_to_str_dtype,
|
||||
pad_vocab_size, str_dtype_to_torch, torch_to_numpy,
|
||||
trt_dtype_to_torch)
|
||||
from ..bindings import KVCacheType, ipc_nvls_allocate, ipc_nvls_free
|
||||
from ..bindings import ipc_nvls_allocate, ipc_nvls_free
|
||||
from ..layers import LanguageAdapterConfig
|
||||
from ..llmapi.kv_cache_type import KVCacheType
|
||||
from ..logger import logger
|
||||
from ..lora_manager import LoraManager
|
||||
from ..mapping import Mapping
|
||||
|
||||
@ -25,9 +25,10 @@ import torch
|
||||
|
||||
from .. import profiler
|
||||
from .._utils import mpi_comm, mpi_world_size, numpy_to_torch
|
||||
from ..bindings import KVCacheType, MpiComm
|
||||
from ..bindings import MpiComm
|
||||
from ..bindings.executor import Executor
|
||||
from ..builder import Engine, EngineConfig, get_engine_version
|
||||
from ..llmapi.kv_cache_type import KVCacheType
|
||||
from ..logger import logger
|
||||
from ..mapping import Mapping
|
||||
from ..quantization import QuantMode
|
||||
@ -86,7 +87,9 @@ def _builder_to_model_config(config: dict) -> Tuple[ModelConfig, dict]:
|
||||
dtype = builder_config['precision']
|
||||
tp_size = builder_config['tensor_parallel']
|
||||
pp_size = builder_config.get('pipeline_parallel', 1)
|
||||
kv_cache_type = KVCacheType.from_string(builder_config.get('kv_cache_type'))
|
||||
kv_cache_type = builder_config.get('kv_cache_type')
|
||||
if kv_cache_type is not None:
|
||||
kv_cache_type = KVCacheType(kv_cache_type)
|
||||
world_size = tp_size * pp_size
|
||||
assert world_size == mpi_world_size(), \
|
||||
f'Engine world size ({tp_size} * {pp_size}) != Runtime world size ({mpi_world_size()})'
|
||||
|
||||
@ -23,13 +23,13 @@ import torch
|
||||
|
||||
from .. import profiler
|
||||
from .._utils import mpi_broadcast
|
||||
from ..bindings import (DataType, GptJsonConfig, KVCacheType, ModelConfig,
|
||||
WorldConfig)
|
||||
from ..bindings import DataType, GptJsonConfig, ModelConfig, WorldConfig
|
||||
from ..bindings import executor as trtllm
|
||||
from ..bindings.executor import (DecodingMode, ExternalDraftTokensConfig,
|
||||
OrchestratorConfig, ParallelConfig)
|
||||
from ..builder import EngineConfig
|
||||
from ..layers import MropeParams
|
||||
from ..llmapi.kv_cache_type import KVCacheType
|
||||
from ..logger import logger
|
||||
from ..mapping import Mapping
|
||||
from .generation import LogitsProcessor, LoraManager
|
||||
@ -248,7 +248,8 @@ class ModelRunnerCpp(ModelRunnerMixin):
|
||||
json_config = GptJsonConfig.parse_file(config_path)
|
||||
model_config = json_config.model_config
|
||||
|
||||
use_kv_cache = model_config.kv_cache_type != KVCacheType.DISABLED
|
||||
use_kv_cache = KVCacheType.from_cpp(
|
||||
model_config.kv_cache_type) != KVCacheType.DISABLED
|
||||
if not model_config.use_cross_attention:
|
||||
assert cross_kv_cache_fraction is None, "cross_kv_cache_fraction should only be used with enc-dec models."
|
||||
|
||||
|
||||
@ -671,7 +671,8 @@ class CliFlowAccuracyTestHarness:
|
||||
f"--max_tokens_in_paged_kv_cache={max_tokens_in_paged_kv_cache}"
|
||||
])
|
||||
|
||||
if task.MAX_INPUT_LEN + task.MAX_OUTPUT_LEN > BuildConfig.max_num_tokens:
|
||||
if task.MAX_INPUT_LEN + task.MAX_OUTPUT_LEN > BuildConfig.model_fields[
|
||||
"max_num_tokens"].default:
|
||||
summarize_cmd.append("--enable_chunked_context")
|
||||
|
||||
if self.extra_summarize_args:
|
||||
|
||||
@ -142,14 +142,14 @@ def test_llmapi_build_command_parameters_align(llm_root, llm_venv, engine_dir,
|
||||
with open(os.path.join(engine_dir, "config.json"), "r") as f:
|
||||
engine_config = json.load(f)
|
||||
|
||||
build_cmd_cfg = BuildConfig.from_dict(
|
||||
engine_config["build_config"]).to_dict()
|
||||
build_cmd_cfg = BuildConfig(
|
||||
**engine_config["build_config"]).model_dump()
|
||||
|
||||
with open(os.path.join(tmpdir.name, "config.json"), "r") as f:
|
||||
llm_api_engine_cfg = json.load(f)
|
||||
|
||||
build_llmapi_cfg = BuildConfig.from_dict(
|
||||
llm_api_engine_cfg["build_config"]).to_dict()
|
||||
build_llmapi_cfg = BuildConfig(
|
||||
**llm_api_engine_cfg["build_config"]).model_dump()
|
||||
|
||||
assert build_cmd_cfg == build_llmapi_cfg
|
||||
|
||||
|
||||
@ -1636,7 +1636,7 @@ def get_allowed_models(benchmark_type=None):
|
||||
if i.benchmark_type == benchmark_type)
|
||||
|
||||
|
||||
def get_build_config(model_name, return_dict=True) -> Union[BuildConfig]:
|
||||
def get_build_config(model_name, return_dict=True) -> Union[Dict, BuildConfig]:
|
||||
if model_name in _allowed_configs:
|
||||
cfg = _allowed_configs[model_name].build_config
|
||||
return asdict(cfg) if return_dict else cfg
|
||||
|
||||
@ -255,7 +255,7 @@ def get_quant_config(quantization: str):

def build_gpt(args):
build_config = get_build_config(args.model)
build_config = BuildConfig.from_dict(build_config)
build_config = BuildConfig(**build_config)
model_config = get_model_config(args.model)
if args.force_num_layer_1:
model_config['num_layers'] = 1
@ -1448,7 +1448,7 @@ def enc_dec_build_helper(component, build_config, model_config, args):

def build_enc_dec(args):
build_config = get_build_config(args.model)
build_config = BuildConfig.from_dict(build_config)
build_config = BuildConfig(**build_config)
model_config = get_model_config(args.model)
if args.force_num_layer_1:
model_config['num_layers'] = 1

@ -9,6 +9,7 @@ import torch
from utils.runtime_defaults import assert_runtime_defaults_are_parsed_correctly

import tensorrt_llm.bindings as _tb
from tensorrt_llm.llmapi.kv_cache_type import KVCacheType
from tensorrt_llm.mapping import Mapping


@ -85,6 +86,7 @@ def test_model_config():
assert model_config.use_packed_input

assert model_config.kv_cache_type is not None
# Test with C++ enums directly
for enum_val in [
_tb.KVCacheType.CONTINUOUS, _tb.KVCacheType.PAGED,
_tb.KVCacheType.DISABLED
@ -92,6 +94,17 @@ def test_model_config():
model_config.kv_cache_type = enum_val
assert model_config.kv_cache_type == enum_val

# Test with Python enums converted to C++
for py_enum in [
KVCacheType.CONTINUOUS, KVCacheType.PAGED, KVCacheType.DISABLED
]:
model_config.kv_cache_type = py_enum.to_cpp()
# Verify it was set correctly by comparing with C++ enum
assert model_config.kv_cache_type == getattr(_tb.KVCacheType,
py_enum.name)
# Also verify round-trip conversion works
assert KVCacheType.from_cpp(model_config.kv_cache_type) == py_enum

assert model_config.tokens_per_block == 64
tokens_per_block = 1024
model_config.tokens_per_block = tokens_per_block

@ -202,7 +202,7 @@ def test_llm_build_config():
# read the build_config and check if the parameters are correctly saved
engine_config = json.load(f)

build_config1 = BuildConfig.from_dict(engine_config["build_config"])
build_config1 = BuildConfig(**engine_config["build_config"])

# Known issue: this will be converted to None after saving the engine for single-GPU
build_config1.plugin_config.nccl_plugin = 'float16'

@ -64,7 +64,7 @@ speculative_config:
dict_content = self._yaml_to_dict(yaml_content)

llm_args = TrtLlmArgs(model=llama_model_path)
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
dict_content)
llm_args = TrtLlmArgs(**llm_args_dict)
assert llm_args.speculative_config.max_window_size == 4
@ -80,7 +80,7 @@ pytorch_backend_config: # this is deprecated
dict_content = self._yaml_to_dict(yaml_content)

llm_args = TrtLlmArgs(model=llama_model_path)
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
dict_content)
with pytest.raises(ValueError):
llm_args = TrtLlmArgs(**llm_args_dict)
@ -96,7 +96,7 @@ build_config:
dict_content = self._yaml_to_dict(yaml_content)

llm_args = TrtLlmArgs(model=llama_model_path)
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
dict_content)
llm_args = TrtLlmArgs(**llm_args_dict)
assert llm_args.build_config.max_beam_width == 4
@ -113,7 +113,7 @@ kv_cache_config:
dict_content = self._yaml_to_dict(yaml_content)

llm_args = TrtLlmArgs(model=llama_model_path)
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
dict_content)
llm_args = TrtLlmArgs(**llm_args_dict)
assert llm_args.kv_cache_config.enable_block_reuse == True
@ -131,7 +131,7 @@ max_seq_len: 128
dict_content = self._yaml_to_dict(yaml_content)

llm_args = TrtLlmArgs(model=llama_model_path)
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
llm_args_dict = update_llm_args_with_extra_dict(llm_args.model_dump(),
dict_content)
llm_args = TrtLlmArgs(**llm_args_dict)
assert llm_args.max_batch_size == 16
@ -331,7 +331,7 @@ def test_update_llm_args_with_extra_dict_with_nested_dict():
lora_config=LoraConfig(lora_ckpt_source='hf'),
plugin_config=plugin_config)
extra_llm_args_dict = {
"build_config": build_config.to_dict(),
"build_config": build_config.model_dump(mode="json"),
}

llm_api_args_dict = update_llm_args_with_extra_dict(llm_api_args_dict,
@ -352,8 +352,9 @@ def test_update_llm_args_with_extra_dict_with_nested_dict():
raise ValueError(f"Mismatch at {path}: {dict1} != {dict2}")
return True

build_config_dict1 = build_config.to_dict()
build_config_dict2 = initialized_llm_args.build_config.to_dict()
build_config_dict1 = build_config.model_dump(mode="json")
build_config_dict2 = initialized_llm_args.build_config.model_dump(
mode="json")
check_nested_dict_equality(build_config_dict1, build_config_dict2)


@ -498,11 +499,11 @@ class TestTrtLlmArgs:
max_num_tokens=256,
)
args = TrtLlmArgs(model=llama_model_path, build_config=build_config)
args_dict = args.to_dict()
args_dict = args.model_dump()

new_args = TrtLlmArgs.from_kwargs(**args_dict)

assert new_args.to_dict() == args_dict
assert new_args.model_dump() == args_dict

def test_build_config_from_engine(self):
build_config = BuildConfig(max_batch_size=8, max_num_tokens=256)
@ -522,6 +523,36 @@ class TestTrtLlmArgs:
assert args.max_num_tokens == 16
assert args.max_batch_size == 4

def test_model_dump_does_not_mutate_original(self):
"""Test that model_dump() and update_llm_args_with_extra_dict don't mutate the original."""
# Create args with specific build_config values
build_config = BuildConfig(
max_batch_size=8,
max_num_tokens=256,
)
args = TrtLlmArgs(model=llama_model_path, build_config=build_config)

# Store original values
original_max_batch_size = args.build_config.max_batch_size
original_max_num_tokens = args.build_config.max_num_tokens

# Convert to dict and pass through update_llm_args_with_extra_dict with overrides
args_dict = args.model_dump()
extra_dict = {
"max_batch_size": 128,
"max_num_tokens": 1024,
}
updated_dict = update_llm_args_with_extra_dict(args_dict, extra_dict)

# Verify original args was NOT mutated
assert args.build_config.max_batch_size == original_max_batch_size
assert args.build_config.max_num_tokens == original_max_num_tokens

# Verify updated dict has new values
new_args = TrtLlmArgs(**updated_dict)
assert new_args.build_config.max_batch_size == 128
assert new_args.build_config.max_num_tokens == 1024


class TestStrictBaseModelArbitraryArgs:
"""Test that StrictBaseModel prevents arbitrary arguments from being accepted."""