mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][fix] Allow YAML config overwriting CLI args for trtllm-eval (#10296)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
This commit is contained in:
parent
f3f02315df
commit
13ffe52ad0
@@ -143,27 +143,30 @@ def main(ctx, model: str, tokenizer: Optional[str],
|
||||
"kv_cache_config": kv_cache_config,
|
||||
}
|
||||
|
||||
if backend == 'pytorch':
|
||||
llm_cls = PyTorchLLM
|
||||
llm_args.update(max_batch_size=max_batch_size,
|
||||
max_num_tokens=max_num_tokens,
|
||||
max_beam_width=max_beam_width,
|
||||
max_seq_len=max_seq_len)
|
||||
elif backend == 'tensorrt':
|
||||
llm_cls = LLM
|
||||
build_config = BuildConfig(max_batch_size=max_batch_size,
|
||||
max_num_tokens=max_num_tokens,
|
||||
max_beam_width=max_beam_width,
|
||||
max_seq_len=max_seq_len)
|
||||
llm_args.update(build_config=build_config)
|
||||
else:
|
||||
raise click.BadParameter(
|
||||
f"{backend} is not a known backend, check help for available options.",
|
||||
param_hint="backend")
|
||||
|
||||
if extra_llm_api_options is not None:
|
||||
llm_args = update_llm_args_with_extra_options(llm_args,
|
||||
extra_llm_api_options)
|
||||
|
||||
profiler.start("trtllm init")
|
||||
if backend == 'pytorch':
|
||||
llm = PyTorchLLM(**llm_args,
|
||||
max_batch_size=max_batch_size,
|
||||
max_num_tokens=max_num_tokens,
|
||||
max_beam_width=max_beam_width,
|
||||
max_seq_len=max_seq_len)
|
||||
elif backend == 'tensorrt':
|
||||
build_config = BuildConfig(max_batch_size=max_batch_size,
|
||||
max_num_tokens=max_num_tokens,
|
||||
max_beam_width=max_beam_width,
|
||||
max_seq_len=max_seq_len)
|
||||
llm = LLM(**llm_args, build_config=build_config)
|
||||
else:
|
||||
raise click.BadParameter(
|
||||
f"{backend} is not a known backend, check help for available options.",
|
||||
param_hint="backend")
|
||||
llm = llm_cls(**llm_args)
|
||||
profiler.stop("trtllm init")
|
||||
elapsed_time = profiler.elapsed_time_in_sec("trtllm init")
|
||||
logger.info(f"TRTLLM initialization time: {elapsed_time:.3f} seconds.")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user