chore: enhance YAML loading of arbitrary options in LlmArgs (#5610)
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
parent 3e75320fe8
commit 2d69b55fe8
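This commit loosens how `update_llm_args_with_extra_dict` consumes extra options from YAML: judging from the hunks below, several Pydantic-backed configs are dropped from the explicit `field_mapping`, and plain options such as `max_batch_size` now load generically. As a sketch of the kind of extra-options file involved (the keys are copied from the new tests; the file name is hypothetical):

# extra_llm_options.yaml (hypothetical name; keys taken from the tests below)
kv_cache_config:
  enable_block_reuse: true
  max_tokens: 1024
max_batch_size: 16
max_seq_len: 128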
@@ -1949,15 +1949,9 @@ def update_llm_args_with_extra_dict(
         "quant_config": QuantConfig,
         "calib_config": CalibConfig,
         "build_config": BuildConfig,
-        "kv_cache_config": KvCacheConfig,
         "decoding_config": DecodingConfig,
         "enable_build_cache": BuildCacheConfig,
-        "peft_cache_config": PeftCacheConfig,
-        "scheduler_config": SchedulerConfig,
         "speculative_config": DecodingBaseConfig,
-        "batching_type": BatchingType,
-        "extended_runtime_perf_knob_config": ExtendedRuntimePerfKnobConfig,
-        "cache_transceiver_config": CacheTransceiverConfig,
         "lora_config": LoraConfig,
     }
     for field_name, field_type in field_mapping.items():
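For intuition, here is a minimal, self-contained sketch of how a `field_mapping` like the one above can drive the conversion of raw dicts parsed from YAML into typed config objects. This is not the real implementation: `apply_extra_dict` and the stand-in `KvCacheConfig` dataclass are invented for illustration only.

from dataclasses import dataclass
from typing import Optional


@dataclass
class KvCacheConfig:
    """Stand-in for the real config class."""
    enable_block_reuse: bool = False
    max_tokens: Optional[int] = None


field_mapping = {"kv_cache_config": KvCacheConfig}


def apply_extra_dict(args: dict, extra: dict) -> dict:
    # Merge extra options over the existing args, then re-hydrate any
    # mapped field whose value is still a raw dict from YAML.
    merged = {**args, **extra}
    for field_name, field_type in field_mapping.items():
        value = merged.get(field_name)
        if isinstance(value, dict):
            merged[field_name] = field_type(**value)
    return merged


updated = apply_extra_dict({}, {"kv_cache_config": {"max_tokens": 1024}})
assert isinstance(updated["kv_cache_config"], KvCacheConfig)

Fields absent from the mapping simply pass through the merge, which is consistent with the Pydantic-backed entries being dropped above: their validation can happen generically when the args object is reconstructed.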
@@ -40,27 +40,100 @@ def test_LookaheadDecodingConfig():
     assert pybind_config.max_verification_set_size == 4


-def test_update_llm_args_with_extra_dict_with_speculative_config():
-    yaml_content = """
-speculative_config:
-    decoding_type: Lookahead
-    max_window_size: 4
-    max_ngram_size: 3
-    verification_set_size: 4
-"""
-    with tempfile.NamedTemporaryFile(delete=False) as f:
-        f.write(yaml_content.encode('utf-8'))
-        f.flush()
-        f.seek(0)
-        dict_content = yaml.safe_load(f)
-
-    llm_args = TrtLlmArgs(model=llama_model_path)
-    llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
-                                                    dict_content)
-    llm_args = TrtLlmArgs(**llm_args_dict)
-    assert llm_args.speculative_config.max_window_size == 4
-    assert llm_args.speculative_config.max_ngram_size == 3
-    assert llm_args.speculative_config.max_verification_set_size == 4
+class TestYaml:
+
+    def _yaml_to_dict(self, yaml_content: str) -> dict:
+        with tempfile.NamedTemporaryFile(delete=False) as f:
+            f.write(yaml_content.encode('utf-8'))
+            f.flush()
+            f.seek(0)
+            dict_content = yaml.safe_load(f)
+        return dict_content
+
+    def test_update_llm_args_with_extra_dict_with_speculative_config(self):
+        yaml_content = """
+speculative_config:
+    decoding_type: Lookahead
+    max_window_size: 4
+    max_ngram_size: 3
+    verification_set_size: 4
+"""
+        dict_content = self._yaml_to_dict(yaml_content)
+
+        llm_args = TrtLlmArgs(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+                                                        dict_content)
+        llm_args = TrtLlmArgs(**llm_args_dict)
+        assert llm_args.speculative_config.max_window_size == 4
+        assert llm_args.speculative_config.max_ngram_size == 3
+        assert llm_args.speculative_config.max_verification_set_size == 4
+
+    def test_llm_args_with_invalid_yaml(self):
+        yaml_content = """
+pytorch_backend_config: # this is deprecated
+    max_num_tokens: 1
+    max_seq_len: 1
+"""
+        dict_content = self._yaml_to_dict(yaml_content)
+
+        llm_args = TrtLlmArgs(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+                                                        dict_content)
+        with pytest.raises(ValueError):
+            llm_args = TrtLlmArgs(**llm_args_dict)
+
+    def test_llm_args_with_build_config(self):
+        # build_config isn't a Pydantic model
+        yaml_content = """
+build_config:
+    max_beam_width: 4
+    max_batch_size: 8
+    max_num_tokens: 256
+"""
+        dict_content = self._yaml_to_dict(yaml_content)
+
+        llm_args = TrtLlmArgs(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+                                                        dict_content)
+        llm_args = TrtLlmArgs(**llm_args_dict)
+        assert llm_args.build_config.max_beam_width == 4
+        assert llm_args.build_config.max_batch_size == 8
+        assert llm_args.build_config.max_num_tokens == 256
+
+    def test_llm_args_with_kvcache_config(self):
+        yaml_content = """
+kv_cache_config:
+    enable_block_reuse: True
+    max_tokens: 1024
+    max_attention_window: [1024, 1024, 1024]
+"""
+        dict_content = self._yaml_to_dict(yaml_content)
+
+        llm_args = TrtLlmArgs(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+                                                        dict_content)
+        llm_args = TrtLlmArgs(**llm_args_dict)
+        assert llm_args.kv_cache_config.enable_block_reuse == True
+        assert llm_args.kv_cache_config.max_tokens == 1024
+        assert llm_args.kv_cache_config.max_attention_window == [
+            1024, 1024, 1024
+        ]
+
+    def test_llm_args_with_pydantic_options(self):
+        yaml_content = """
+max_batch_size: 16
+max_num_tokens: 256
+max_seq_len: 128
+"""
+        dict_content = self._yaml_to_dict(yaml_content)
+
+        llm_args = TrtLlmArgs(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+                                                        dict_content)
+        llm_args = TrtLlmArgs(**llm_args_dict)
+        assert llm_args.max_batch_size == 16
+        assert llm_args.max_num_tokens == 256
+        assert llm_args.max_seq_len == 128


 def check_defaults(py_config_cls, pybind_config_cls):
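End to end, the caller-side flow the new tests exercise looks roughly like the following sketch. The import paths and file name are assumptions (only the function and class names appear in the diff above), and llama_model_path must point at a real local checkpoint.

import yaml

# Assumed import locations; only the names are confirmed by the diff above.
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_dict
from tensorrt_llm.llmapi.llm_args import TrtLlmArgs

llama_model_path = "/path/to/llama"  # placeholder checkpoint path

with open("extra_llm_options.yaml") as f:  # hypothetical options file
    extra = yaml.safe_load(f)

llm_args = TrtLlmArgs(model=llama_model_path)
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(), extra)
llm_args = TrtLlmArgs(**llm_args_dict)  # re-validation rejects bad options

As the test_llm_args_with_invalid_yaml case shows, options that fail validation (such as the deprecated pytorch_backend_config key) surface as a ValueError when the args object is rebuilt.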