diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py
index b2b22aa1e0..19142ab13b 100644
--- a/tests/unittest/llmapi/test_llm.py
+++ b/tests/unittest/llmapi/test_llm.py
@@ -1642,9 +1642,12 @@ def llm_return_logprobs_test_harness(prompt_logprobs: Optional[int],
                                      streaming=False,
                                      backend=None):
     LLM_CLASS = LLM
+    llm_extra_kwargs = {}
     if backend == "pytorch":
         from tensorrt_llm._torch import LLM as LLM_torch
         LLM_CLASS = LLM_torch
+    else:
+        llm_extra_kwargs["fast_build"] = True
 
     llm = LLM_CLASS(
         llama_model_path,
@@ -1652,6 +1655,7 @@ def llm_return_logprobs_test_harness(prompt_logprobs: Optional[int],
         build_config=BuildConfig(gather_context_logits=True),
         tensor_parallel_size=tp_size,
         gather_generation_logits=True,
+        **llm_extra_kwargs,
     )
 
     prompts = ["A B C D E F G H I J K"]
@@ -1950,7 +1954,6 @@ def test_llm_get_queued_stats():
 
     llm = LLM_CLASS(model=llama_model_path,
                     kv_cache_config=global_kvcache_config,
                     tensor_parallel_size=tp_size,
-                    fast_build=True,
                     max_batch_size=1,
                     **llm_args_extra)