Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
This commit is contained in: parent 4d0bcbcb2d, commit c18b632160
@@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import dataclasses
import json
import math
import os
import shutil
import time
from dataclasses import dataclass, field
from functools import cache
from pathlib import Path
from typing import Dict, Optional, Union

@@ -557,53 +559,89 @@ class BuildConfig:
        else:
            override_attri('paged_state', False)

    @classmethod
    @cache
    def get_build_config_defaults(cls):
        return {
            field.name: field.default
            for field in dataclasses.fields(cls)
            if field.default is not dataclasses.MISSING
        }

    @classmethod
    def from_dict(cls, config, plugin_config=None):
        config = copy.deepcopy(
            config
        )  # it just does not make sense to change the input arg `config`
        max_input_len = config.pop('max_input_len')
        max_seq_len = config.pop('max_seq_len')
        max_batch_size = config.pop('max_batch_size')
        max_beam_width = config.pop('max_beam_width')
        max_num_tokens = config.pop('max_num_tokens')
        opt_num_tokens = config.pop('opt_num_tokens')
        opt_batch_size = config.pop('opt_batch_size', 8)
        max_prompt_embedding_table_size = config.pop(
            'max_prompt_embedding_table_size', 0)

        kv_cache_type = KVCacheType(
            config.pop('kv_cache_type')) if 'plugin_config' in config else None
        gather_context_logits = config.pop('gather_context_logits', False)
        gather_generation_logits = config.pop('gather_generation_logits', False)
        strongly_typed = config.pop('strongly_typed', True)
        force_num_profiles = config.pop('force_num_profiles', None)
        weight_sparsity = config.pop('weight_sparsity', False)
        defaults = cls.get_build_config_defaults()
        max_input_len = config.pop('max_input_len',
                                   defaults.get('max_input_len'))
        max_seq_len = config.pop('max_seq_len', defaults.get('max_seq_len'))
        max_batch_size = config.pop('max_batch_size',
                                    defaults.get('max_batch_size'))
        max_beam_width = config.pop('max_beam_width',
                                    defaults.get('max_beam_width'))
        max_num_tokens = config.pop('max_num_tokens',
                                    defaults.get('max_num_tokens'))
        opt_num_tokens = config.pop('opt_num_tokens',
                                    defaults.get('opt_num_tokens'))
        opt_batch_size = config.pop('opt_batch_size',
                                    defaults.get('opt_batch_size'))
        max_prompt_embedding_table_size = config.pop(
            'max_prompt_embedding_table_size',
            defaults.get('max_prompt_embedding_table_size'))

        if "kv_cache_type" in config and config["kv_cache_type"] is not None:
            kv_cache_type = KVCacheType(config.pop('kv_cache_type'))
        else:
            kv_cache_type = None
        gather_context_logits = config.pop(
            'gather_context_logits', defaults.get('gather_context_logits'))
        gather_generation_logits = config.pop(
            'gather_generation_logits',
            defaults.get('gather_generation_logits'))
        strongly_typed = config.pop('strongly_typed',
                                    defaults.get('strongly_typed'))
        force_num_profiles = config.pop('force_num_profiles',
                                        defaults.get('force_num_profiles'))
        weight_sparsity = config.pop('weight_sparsity',
                                     defaults.get('weight_sparsity'))
        profiling_verbosity = config.pop('profiling_verbosity',
                                         'layer_names_only')
        enable_debug_output = config.pop('enable_debug_output', False)
        max_draft_len = config.pop('max_draft_len', 0)
        speculative_decoding_mode = config.pop('speculative_decoding_mode',
                                               SpeculativeDecodingMode.NONE)
        use_refit = config.pop('use_refit', False)
        input_timing_cache = config.pop('input_timing_cache', None)
        output_timing_cache = config.pop('output_timing_cache', None)
                                         defaults.get('profiling_verbosity'))
        enable_debug_output = config.pop('enable_debug_output',
                                         defaults.get('enable_debug_output'))
        max_draft_len = config.pop('max_draft_len',
                                   defaults.get('max_draft_len'))
        speculative_decoding_mode = config.pop(
            'speculative_decoding_mode',
            defaults.get('speculative_decoding_mode'))
        use_refit = config.pop('use_refit', defaults.get('use_refit'))
        input_timing_cache = config.pop('input_timing_cache',
                                        defaults.get('input_timing_cache'))
        output_timing_cache = config.pop('output_timing_cache',
                                         defaults.get('output_timing_cache'))
        lora_config = LoraConfig.from_dict(config.get('lora_config', {}))
        auto_parallel_config = AutoParallelConfig.from_dict(
            config.get('auto_parallel_config', {}))
        max_encoder_input_len = config.pop('max_encoder_input_len', 1024)
        weight_streaming = config.pop('weight_streaming', False)
        use_strip_plan = config.pop('use_strip_plan', False)
        max_encoder_input_len = config.pop(
            'max_encoder_input_len', defaults.get('max_encoder_input_len'))
        weight_streaming = config.pop('weight_streaming',
                                      defaults.get('weight_streaming'))
        use_strip_plan = config.pop('use_strip_plan',
                                    defaults.get('use_strip_plan'))

        if plugin_config is None:
            plugin_config = PluginConfig()
        if "plugin_config" in config.keys():
            plugin_config.update_from_dict(config["plugin_config"])

        dry_run = config.pop('dry_run', False)
        visualize_network = config.pop('visualize_network', None)
        monitor_memory = config.pop('monitor_memory', False)
        use_mrope = config.pop('use_mrope', False)
        dry_run = config.pop('dry_run', defaults.get('dry_run'))
        visualize_network = config.pop('visualize_network',
                                       defaults.get('visualize_network'))
        monitor_memory = config.pop('monitor_memory',
                                    defaults.get('monitor_memory'))
        use_mrope = config.pop('use_mrope', defaults.get('use_mrope'))

        return cls(
            max_input_len=max_input_len,
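The helper shown above memoizes the dataclass field defaults (a @classmethod wrapped in functools.cache over dataclasses.fields) so that from_dict can fall back to the declared defaults instead of repeating literal values. Below is a minimal standalone sketch of that pattern; DemoConfig and its three fields are hypothetical stand-ins, not the real BuildConfig.

# Minimal sketch (hypothetical DemoConfig, not the real BuildConfig) of a
# from_dict that falls back to the dataclass defaults instead of hard-coded
# literals.
import copy
import dataclasses
from dataclasses import dataclass
from functools import cache


@dataclass
class DemoConfig:
    max_input_len: int = 1024
    max_num_tokens: int = 8192
    strongly_typed: bool = True

    @classmethod
    @cache
    def get_defaults(cls):
        # Computed once and memoized: maps field name -> declared default.
        return {
            f.name: f.default
            for f in dataclasses.fields(cls)
            if f.default is not dataclasses.MISSING
        }

    @classmethod
    def from_dict(cls, config):
        config = copy.deepcopy(config)  # do not mutate the caller's dict
        defaults = cls.get_defaults()
        max_input_len = config.pop('max_input_len',
                                   defaults.get('max_input_len'))
        max_num_tokens = config.pop('max_num_tokens',
                                    defaults.get('max_num_tokens'))
        strongly_typed = config.pop('strongly_typed',
                                    defaults.get('strongly_typed'))
        return cls(max_input_len=max_input_len,
                   max_num_tokens=max_num_tokens,
                   strongly_typed=strongly_typed)


# Keys missing from the input dict resolve to the dataclass defaults.
assert DemoConfig.from_dict({'max_input_len': 256}) == DemoConfig(max_input_len=256)

Because functools.cache keys on the cls argument, the defaults mapping is computed once per class and reused on every subsequent from_dict call.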
@@ -2014,7 +2014,8 @@ def update_llm_args_with_extra_dict(
    }
    for field_name, field_type in field_mapping.items():
        if field_name in llm_args_dict:
            if field_name == "speculative_config":
            # Some fields need to be converted manually.
            if field_name in ["speculative_config", "build_config"]:
                llm_args_dict[field_name] = field_type.from_dict(
                    llm_args_dict[field_name])
            else:
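The loop above converts any entry named in field_mapping that is still a plain dict into a typed config object via that class's from_dict. A minimal standalone sketch of the same dispatch follows; DemoBuildConfig and demo_update_llm_args are hypothetical simplifications, not the actual TensorRT-LLM helpers.

# Minimal sketch of the dispatch: entries named in field_mapping that arrive
# as plain dicts are rebuilt through their config class's from_dict; the names
# and fields here are hypothetical.
from dataclasses import dataclass


@dataclass
class DemoBuildConfig:
    max_input_len: int = 1024

    @classmethod
    def from_dict(cls, config):
        return cls(**config)


def demo_update_llm_args(llm_args_dict, extra_dict):
    llm_args_dict = {**llm_args_dict, **extra_dict}
    field_mapping = {"build_config": DemoBuildConfig}
    for field_name, field_type in field_mapping.items():
        if field_name in llm_args_dict and isinstance(llm_args_dict[field_name], dict):
            # Convert the nested dict into a typed config object.
            llm_args_dict[field_name] = field_type.from_dict(
                llm_args_dict[field_name])
    return llm_args_dict


args = demo_update_llm_args({"model": "dummy-model"},
                            {"build_config": {"max_input_len": 256}})
assert isinstance(args["build_config"], DemoBuildConfig)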
@@ -5,10 +5,15 @@ import pytest
import yaml

import tensorrt_llm.bindings.executor as tle
from tensorrt_llm import AutoParallelConfig
from tensorrt_llm._torch.llm import LLM as TorchLLM
from tensorrt_llm.builder import LoraConfig
from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
                                 SchedulerConfig)
from tensorrt_llm.llmapi.llm import LLM
from tensorrt_llm.llmapi.llm_args import *
from tensorrt_llm.llmapi.utils import print_traceback_on_error
from tensorrt_llm.plugin import PluginConfig

from .test_llm import llama_model_path

@@ -179,6 +184,61 @@ def test_PeftCacheConfig_declaration():
    assert pybind_config.lora_prefetch_dir == "."


def test_update_llm_args_with_extra_dict_with_nested_dict():
    llm_api_args_dict = {
        "model":
        "dummy-model",
        "build_config":
        None,  # Will override later.
        "extended_runtime_perf_knob_config":
        ExtendedRuntimePerfKnobConfig(multi_block_mode=True),
        "kv_cache_config":
        KvCacheConfig(enable_block_reuse=False),
        "peft_cache_config":
        PeftCacheConfig(num_host_module_layer=0),
        "scheduler_config":
        SchedulerConfig(capacity_scheduler_policy=CapacitySchedulerPolicy.
                        GUARANTEED_NO_EVICT)
    }
    plugin_config_dict = {
        "_dtype": 'float16',
        "nccl_plugin": None,
    }
    plugin_config = PluginConfig.from_dict(plugin_config_dict)
    build_config = BuildConfig(max_input_len=1024,
                               lora_config=LoraConfig(lora_ckpt_source='hf'),
                               auto_parallel_config=AutoParallelConfig(
                                   world_size=1,
                                   same_buffer_io={},
                                   debug_outputs=[]),
                               plugin_config=plugin_config)
    extra_llm_args_dict = {
        "build_config": build_config.to_dict(),
    }

    llm_api_args_dict = update_llm_args_with_extra_dict(llm_api_args_dict,
                                                        extra_llm_args_dict,
                                                        "build_config")
    initialized_llm_args = TrtLlmArgs(**llm_api_args_dict)

    def check_nested_dict_equality(dict1, dict2, path=""):
        if not isinstance(dict1, dict) or not isinstance(dict2, dict):
            if dict1 != dict2:
                raise ValueError(f"Mismatch at {path}: {dict1} != {dict2}")
            return True
        if dict1.keys() != dict2.keys():
            raise ValueError(f"Different keys at {path}:")
        for key in dict1:
            new_path = f"{path}.{key}" if path else key
            if not check_nested_dict_equality(dict1[key], dict2[key], new_path):
                raise ValueError(f"Mismatch at {path}: {dict1} != {dict2}")
        return True

    build_config_dict1 = build_config.to_dict()
    build_config_dict2 = initialized_llm_args.build_config.to_dict()
    check_nested_dict_equality(build_config_dict1, build_config_dict2)


class TestTorchLlmArgsCudaGraphSettings:

    def test_cuda_graph_batch_sizes_case_0(self):