fix: Update trtllm args issues with extra nested config (#5996) (#6114)

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Wanli Jiang 2025-07-18 04:13:22 +08:00 committed by GitHub
parent 4d0bcbcb2d
commit c18b632160
3 changed files with 131 additions and 32 deletions
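
For context before the diff: the issue being fixed is that extra LLM-API options (for example, a YAML options file layered on top of the base arguments) can carry a nested `build_config` section, and that nested dict previously reached the argument class unconverted. A minimal sketch of the now-working flow, mirroring the imports and the three-argument call used by the new unit test at the end of this commit; this is an illustration, not an official example:

# Sketch only: mirrors the new test below.
from tensorrt_llm.llmapi import BuildConfig
from tensorrt_llm.llmapi.llm_args import (TrtLlmArgs,
                                          update_llm_args_with_extra_dict)

llm_args = {"model": "dummy-model", "build_config": None}
# A nested dict such as one loaded from an extra-options YAML file.
extra = {"build_config": BuildConfig(max_input_len=1024).to_dict()}

llm_args = update_llm_args_with_extra_dict(llm_args, extra, "build_config")
args = TrtLlmArgs(**llm_args)  # build_config is rebuilt via from_dict()
assert args.build_config.max_input_len == 1024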


@@ -13,12 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+import dataclasses
 import json
 import math
 import os
 import shutil
 import time
 from dataclasses import dataclass, field
+from functools import cache
 from pathlib import Path
 from typing import Dict, Optional, Union
@@ -557,53 +559,89 @@ class BuildConfig:
         else:
             override_attri('paged_state', False)
 
+    @classmethod
+    @cache
+    def get_build_config_defaults(cls):
+        return {
+            field.name: field.default
+            for field in dataclasses.fields(cls)
+            if field.default is not dataclasses.MISSING
+        }
+
     @classmethod
     def from_dict(cls, config, plugin_config=None):
         config = copy.deepcopy(
             config
         )  # it just does not make sense to change the input arg `config`
-        max_input_len = config.pop('max_input_len')
-        max_seq_len = config.pop('max_seq_len')
-        max_batch_size = config.pop('max_batch_size')
-        max_beam_width = config.pop('max_beam_width')
-        max_num_tokens = config.pop('max_num_tokens')
-        opt_num_tokens = config.pop('opt_num_tokens')
-        opt_batch_size = config.pop('opt_batch_size', 8)
-        max_prompt_embedding_table_size = config.pop(
-            'max_prompt_embedding_table_size', 0)
-        kv_cache_type = KVCacheType(
-            config.pop('kv_cache_type')) if 'plugin_config' in config else None
-        gather_context_logits = config.pop('gather_context_logits', False)
-        gather_generation_logits = config.pop('gather_generation_logits', False)
-        strongly_typed = config.pop('strongly_typed', True)
-        force_num_profiles = config.pop('force_num_profiles', None)
-        weight_sparsity = config.pop('weight_sparsity', False)
+        defaults = cls.get_build_config_defaults()
+        max_input_len = config.pop('max_input_len',
+                                   defaults.get('max_input_len'))
+        max_seq_len = config.pop('max_seq_len', defaults.get('max_seq_len'))
+        max_batch_size = config.pop('max_batch_size',
+                                    defaults.get('max_batch_size'))
+        max_beam_width = config.pop('max_beam_width',
+                                    defaults.get('max_beam_width'))
+        max_num_tokens = config.pop('max_num_tokens',
+                                    defaults.get('max_num_tokens'))
+        opt_num_tokens = config.pop('opt_num_tokens',
+                                    defaults.get('opt_num_tokens'))
+        opt_batch_size = config.pop('opt_batch_size',
+                                    defaults.get('opt_batch_size'))
+        max_prompt_embedding_table_size = config.pop(
+            'max_prompt_embedding_table_size',
+            defaults.get('max_prompt_embedding_table_size'))
+        if "kv_cache_type" in config and config["kv_cache_type"] is not None:
+            kv_cache_type = KVCacheType(config.pop('kv_cache_type'))
+        else:
+            kv_cache_type = None
+        gather_context_logits = config.pop(
+            'gather_context_logits', defaults.get('gather_context_logits'))
+        gather_generation_logits = config.pop(
+            'gather_generation_logits',
+            defaults.get('gather_generation_logits'))
+        strongly_typed = config.pop('strongly_typed',
+                                    defaults.get('strongly_typed'))
+        force_num_profiles = config.pop('force_num_profiles',
+                                        defaults.get('force_num_profiles'))
+        weight_sparsity = config.pop('weight_sparsity',
+                                     defaults.get('weight_sparsity'))
         profiling_verbosity = config.pop('profiling_verbosity',
-                                         'layer_names_only')
-        enable_debug_output = config.pop('enable_debug_output', False)
-        max_draft_len = config.pop('max_draft_len', 0)
-        speculative_decoding_mode = config.pop('speculative_decoding_mode',
-                                               SpeculativeDecodingMode.NONE)
-        use_refit = config.pop('use_refit', False)
-        input_timing_cache = config.pop('input_timing_cache', None)
-        output_timing_cache = config.pop('output_timing_cache', None)
+                                         defaults.get('profiling_verbosity'))
+        enable_debug_output = config.pop('enable_debug_output',
+                                         defaults.get('enable_debug_output'))
+        max_draft_len = config.pop('max_draft_len',
+                                   defaults.get('max_draft_len'))
+        speculative_decoding_mode = config.pop(
+            'speculative_decoding_mode',
+            defaults.get('speculative_decoding_mode'))
+        use_refit = config.pop('use_refit', defaults.get('use_refit'))
+        input_timing_cache = config.pop('input_timing_cache',
+                                        defaults.get('input_timing_cache'))
+        output_timing_cache = config.pop('output_timing_cache',
+                                         defaults.get('output_timing_cache'))
         lora_config = LoraConfig.from_dict(config.get('lora_config', {}))
         auto_parallel_config = AutoParallelConfig.from_dict(
             config.get('auto_parallel_config', {}))
-        max_encoder_input_len = config.pop('max_encoder_input_len', 1024)
-        weight_streaming = config.pop('weight_streaming', False)
-        use_strip_plan = config.pop('use_strip_plan', False)
+        max_encoder_input_len = config.pop(
+            'max_encoder_input_len', defaults.get('max_encoder_input_len'))
+        weight_streaming = config.pop('weight_streaming',
+                                      defaults.get('weight_streaming'))
+        use_strip_plan = config.pop('use_strip_plan',
+                                    defaults.get('use_strip_plan'))
 
         if plugin_config is None:
             plugin_config = PluginConfig()
         if "plugin_config" in config.keys():
             plugin_config.update_from_dict(config["plugin_config"])
-        dry_run = config.pop('dry_run', False)
-        visualize_network = config.pop('visualize_network', None)
-        monitor_memory = config.pop('monitor_memory', False)
-        use_mrope = config.pop('use_mrope', False)
+        dry_run = config.pop('dry_run', defaults.get('dry_run'))
+        visualize_network = config.pop('visualize_network',
+                                       defaults.get('visualize_network'))
+        monitor_memory = config.pop('monitor_memory',
+                                    defaults.get('monitor_memory'))
+        use_mrope = config.pop('use_mrope', defaults.get('use_mrope'))
 
         return cls(
             max_input_len=max_input_len,

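The heart of the builder-side change is `get_build_config_defaults`: instead of duplicating fallback values at every `config.pop(...)` call site, `from_dict` now reads them once from the dataclass definition itself, so a partial config dict no longer raises `KeyError` on missing keys. A self-contained sketch of the same pattern on a hypothetical stand-in dataclass:

# Stand-in dataclass; BuildConfig's real fields are in the diff above.
import dataclasses
from dataclasses import dataclass
from functools import cache


@dataclass
class StubConfig:
    max_input_len: int = 1024
    opt_batch_size: int = 8


@cache  # dataclass defaults never change, so compute the dict once
def stub_defaults():
    return {
        f.name: f.default
        for f in dataclasses.fields(StubConfig)
        if f.default is not dataclasses.MISSING
    }


config = {"max_input_len": 2048}
# Missing keys fall back to declared defaults instead of raising KeyError.
max_input_len = config.pop("max_input_len",
                           stub_defaults().get("max_input_len"))
opt_batch_size = config.pop("opt_batch_size",
                            stub_defaults().get("opt_batch_size"))
assert (max_input_len, opt_batch_size) == (2048, 8)

Note that fields declared with `default_factory` report `default is MISSING` and are skipped by the comprehension, which is consistent with the call sites above using `defaults.get(...)` (quietly yielding `None`) rather than indexing `defaults[...]` directly.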

@@ -2014,7 +2014,8 @@ def update_llm_args_with_extra_dict(
     }
     for field_name, field_type in field_mapping.items():
         if field_name in llm_args_dict:
-            if field_name == "speculative_config":
+            # Some fields need to be converted manually.
+            if field_name in ["speculative_config", "build_config"]:
                 llm_args_dict[field_name] = field_type.from_dict(
                     llm_args_dict[field_name])
             else:

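The one-line dispatch change above is the piece that lets a nested `build_config` dict survive the merge: `BuildConfig` is now constructed through its own `from_dict`, just like `speculative_config`, instead of being forwarded as a raw dict. A reduced, runnable sketch of the loop; the stand-in class and the `else` branch are assumptions for illustration, not the exact source:

from dataclasses import dataclass


@dataclass
class FakeBuildConfig:  # hypothetical stand-in for BuildConfig
    max_input_len: int = 1024

    @classmethod
    def from_dict(cls, d):
        return cls(**d)


field_mapping = {"build_config": FakeBuildConfig}  # real mapping is larger
llm_args_dict = {"build_config": {"max_input_len": 2048}}

for field_name, field_type in field_mapping.items():
    if field_name in llm_args_dict:
        # Some fields need to be converted manually.
        if field_name in ["speculative_config", "build_config"]:
            llm_args_dict[field_name] = field_type.from_dict(
                llm_args_dict[field_name])
        else:  # assumed: other fields construct directly from the dict
            llm_args_dict[field_name] = field_type(
                **llm_args_dict[field_name])

assert isinstance(llm_args_dict["build_config"], FakeBuildConfig)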

@@ -5,10 +5,15 @@ import pytest
 import yaml
 
 import tensorrt_llm.bindings.executor as tle
 from tensorrt_llm import AutoParallelConfig
+from tensorrt_llm._torch.llm import LLM as TorchLLM
+from tensorrt_llm.builder import LoraConfig
+from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
+                                 SchedulerConfig)
 from tensorrt_llm.llmapi.llm import LLM
 from tensorrt_llm.llmapi.llm_args import *
 from tensorrt_llm.llmapi.utils import print_traceback_on_error
+from tensorrt_llm.plugin import PluginConfig
 
 from .test_llm import llama_model_path
@@ -179,6 +184,61 @@ def test_PeftCacheConfig_declaration():
     assert pybind_config.lora_prefetch_dir == "."
 
 
+def test_update_llm_args_with_extra_dict_with_nested_dict():
+    llm_api_args_dict = {
+        "model":
+        "dummy-model",
+        "build_config":
+        None,  # Will override later.
+        "extended_runtime_perf_knob_config":
+        ExtendedRuntimePerfKnobConfig(multi_block_mode=True),
+        "kv_cache_config":
+        KvCacheConfig(enable_block_reuse=False),
+        "peft_cache_config":
+        PeftCacheConfig(num_host_module_layer=0),
+        "scheduler_config":
+        SchedulerConfig(capacity_scheduler_policy=CapacitySchedulerPolicy.
+                        GUARANTEED_NO_EVICT)
+    }
+    plugin_config_dict = {
+        "_dtype": 'float16',
+        "nccl_plugin": None,
+    }
+    plugin_config = PluginConfig.from_dict(plugin_config_dict)
+    build_config = BuildConfig(max_input_len=1024,
+                               lora_config=LoraConfig(lora_ckpt_source='hf'),
+                               auto_parallel_config=AutoParallelConfig(
+                                   world_size=1,
+                                   same_buffer_io={},
+                                   debug_outputs=[]),
+                               plugin_config=plugin_config)
+    extra_llm_args_dict = {
+        "build_config": build_config.to_dict(),
+    }
+
+    llm_api_args_dict = update_llm_args_with_extra_dict(llm_api_args_dict,
+                                                        extra_llm_args_dict,
+                                                        "build_config")
+    initialized_llm_args = TrtLlmArgs(**llm_api_args_dict)
+
+    def check_nested_dict_equality(dict1, dict2, path=""):
+        if not isinstance(dict1, dict) or not isinstance(dict2, dict):
+            if dict1 != dict2:
+                raise ValueError(f"Mismatch at {path}: {dict1} != {dict2}")
+            return True
+        if dict1.keys() != dict2.keys():
+            raise ValueError(f"Different keys at {path}:")
+        for key in dict1:
+            new_path = f"{path}.{key}" if path else key
+            if not check_nested_dict_equality(dict1[key], dict2[key],
+                                              new_path):
+                raise ValueError(f"Mismatch at {path}: {dict1} != {dict2}")
+        return True
+
+    build_config_dict1 = build_config.to_dict()
+    build_config_dict2 = initialized_llm_args.build_config.to_dict()
+    check_nested_dict_equality(build_config_dict1, build_config_dict2)
+
+
 class TestTorchLlmArgsCudaGraphSettings:
 
     def test_cuda_graph_batch_sizes_case_0(self):