chore: enhance yaml loading arbitrary options in LlmArgs (#5610)

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
Yan Chunwei 2025-07-02 14:21:37 +08:00 committed by GitHub
parent 3e75320fe8
commit 2d69b55fe8
2 changed files with 93 additions and 26 deletions
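
The change lets an extra YAML dict carry both nested config sections and arbitrary top-level LlmArgs fields. Below is a minimal sketch of the flow the new tests exercise; the import paths and the model path are assumptions, not taken from this diff.

import yaml

# Assumed import location; the test module imports these names directly.
from tensorrt_llm.llmapi.llm_args import TrtLlmArgs, update_llm_args_with_extra_dict

yaml_content = """
kv_cache_config:
    enable_block_reuse: True
    max_tokens: 1024
max_batch_size: 16  # top-level Pydantic fields are accepted too
"""

extra_dict = yaml.safe_load(yaml_content)
llm_args = TrtLlmArgs(model="/path/to/model")  # placeholder model path
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(), extra_dict)
llm_args = TrtLlmArgs(**llm_args_dict)
assert llm_args.kv_cache_config.max_tokens == 1024
assert llm_args.max_batch_size == 16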


@@ -1949,15 +1949,9 @@ def update_llm_args_with_extra_dict(
        "quant_config": QuantConfig,
        "calib_config": CalibConfig,
        "build_config": BuildConfig,
        "kv_cache_config": KvCacheConfig,
        "decoding_config": DecodingConfig,
        "enable_build_cache": BuildCacheConfig,
        "peft_cache_config": PeftCacheConfig,
        "scheduler_config": SchedulerConfig,
        "speculative_config": DecodingBaseConfig,
        "batching_type": BatchingType,
        "extended_runtime_perf_knob_config": ExtendedRuntimePerfKnobConfig,
        "cache_transceiver_config": CacheTransceiverConfig,
        "lora_config": LoraConfig,
    }
    for field_name, field_type in field_mapping.items():
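
field_mapping pairs each recognized extra-dict key with the config class it should be deserialized into before the merged dict is fed back to LlmArgs. The helper below is an illustrative sketch of that conversion step only, not the function's actual implementation (the helper name and the handling of non-Pydantic configs such as BuildConfig are assumptions).

def convert_nested_configs(extra_llm_api_options: dict, field_mapping: dict) -> dict:
    # Hypothetical helper: replace plain dicts parsed from YAML with typed config objects.
    for field_name, field_type in field_mapping.items():
        value = extra_llm_api_options.get(field_name)
        if isinstance(value, dict):
            extra_llm_api_options[field_name] = field_type(**value)
    return extra_llm_api_options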


@@ -40,27 +40,100 @@ def test_LookaheadDecodingConfig():
    assert pybind_config.max_verification_set_size == 4

def test_update_llm_args_with_extra_dict_with_speculative_config():
    yaml_content = """
speculative_config:
    decoding_type: Lookahead
    max_window_size: 4
    max_ngram_size: 3
    verification_set_size: 4
"""
    with tempfile.NamedTemporaryFile(delete=False) as f:
        f.write(yaml_content.encode('utf-8'))
        f.flush()
        f.seek(0)
        dict_content = yaml.safe_load(f)

class TestYaml:
    llm_args = TrtLlmArgs(model=llama_model_path)
    llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
                                                    dict_content)
    llm_args = TrtLlmArgs(**llm_args_dict)
    assert llm_args.speculative_config.max_window_size == 4
    assert llm_args.speculative_config.max_ngram_size == 3
    assert llm_args.speculative_config.max_verification_set_size == 4

    def _yaml_to_dict(self, yaml_content: str) -> dict:
        with tempfile.NamedTemporaryFile(delete=False) as f:
            f.write(yaml_content.encode('utf-8'))
            f.flush()
            f.seek(0)
            dict_content = yaml.safe_load(f)
        return dict_content

    def test_update_llm_args_with_extra_dict_with_speculative_config(self):
        yaml_content = """
speculative_config:
    decoding_type: Lookahead
    max_window_size: 4
    max_ngram_size: 3
    verification_set_size: 4
"""
        dict_content = self._yaml_to_dict(yaml_content)
        llm_args = TrtLlmArgs(model=llama_model_path)
        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
                                                        dict_content)
        llm_args = TrtLlmArgs(**llm_args_dict)
        assert llm_args.speculative_config.max_window_size == 4
        assert llm_args.speculative_config.max_ngram_size == 3
        assert llm_args.speculative_config.max_verification_set_size == 4

    def test_llm_args_with_invalid_yaml(self):
        yaml_content = """
pytorch_backend_config: # this is deprecated
    max_num_tokens: 1
    max_seq_len: 1
"""
        dict_content = self._yaml_to_dict(yaml_content)
        llm_args = TrtLlmArgs(model=llama_model_path)
        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
                                                        dict_content)
        with pytest.raises(ValueError):
            llm_args = TrtLlmArgs(**llm_args_dict)

    def test_llm_args_with_build_config(self):
        # build_config isn't a Pydantic model
        yaml_content = """
build_config:
    max_beam_width: 4
    max_batch_size: 8
    max_num_tokens: 256
"""
        dict_content = self._yaml_to_dict(yaml_content)
        llm_args = TrtLlmArgs(model=llama_model_path)
        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
                                                        dict_content)
        llm_args = TrtLlmArgs(**llm_args_dict)
        assert llm_args.build_config.max_beam_width == 4
        assert llm_args.build_config.max_batch_size == 8
        assert llm_args.build_config.max_num_tokens == 256

    def test_llm_args_with_kvcache_config(self):
        yaml_content = """
kv_cache_config:
    enable_block_reuse: True
    max_tokens: 1024
    max_attention_window: [1024, 1024, 1024]
"""
        dict_content = self._yaml_to_dict(yaml_content)
        llm_args = TrtLlmArgs(model=llama_model_path)
        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
                                                        dict_content)
        llm_args = TrtLlmArgs(**llm_args_dict)
        assert llm_args.kv_cache_config.enable_block_reuse == True
        assert llm_args.kv_cache_config.max_tokens == 1024
        assert llm_args.kv_cache_config.max_attention_window == [
            1024, 1024, 1024
        ]

    def test_llm_args_with_pydantic_options(self):
        yaml_content = """
max_batch_size: 16
max_num_tokens: 256
max_seq_len: 128
"""
        dict_content = self._yaml_to_dict(yaml_content)
        llm_args = TrtLlmArgs(model=llama_model_path)
        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
                                                        dict_content)
        llm_args = TrtLlmArgs(**llm_args_dict)
        assert llm_args.max_batch_size == 16
        assert llm_args.max_num_tokens == 256
        assert llm_args.max_seq_len == 128

def check_defaults(py_config_cls, pybind_config_cls):