Only pass fast_build=true to non-pytorch backend (#4920)
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
This commit is contained in:
parent 9ceef983c0
commit ddbaa5ef80
@@ -1642,9 +1642,12 @@ def llm_return_logprobs_test_harness(prompt_logprobs: Optional[int],
                                      streaming=False,
                                      backend=None):
     LLM_CLASS = LLM
+    llm_extra_kwargs = {}
     if backend == "pytorch":
         from tensorrt_llm._torch import LLM as LLM_torch
         LLM_CLASS = LLM_torch
+    else:
+        llm_extra_kwargs["fast_build"] = True

     llm = LLM_CLASS(
         llama_model_path,
@@ -1652,6 +1655,7 @@ def llm_return_logprobs_test_harness(prompt_logprobs: Optional[int],
         build_config=BuildConfig(gather_context_logits=True),
         tensor_parallel_size=tp_size,
         gather_generation_logits=True,
+        **llm_extra_kwargs,
     )

     prompts = ["A B C D E F G H I J K"]
@@ -1950,7 +1954,6 @@ def test_llm_get_queued_stats():
     llm = LLM_CLASS(model=llama_model_path,
                     kv_cache_config=global_kvcache_config,
                     tensor_parallel_size=tp_size,
-                    fast_build=True,
                     max_batch_size=1,
                     **llm_args_extra)

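For context on the diff above: the harness now builds its keyword arguments conditionally so that fast_build is forwarded only when the engine-building (TRT) LLM class is used, never to the PyTorch backend. Below is a minimal sketch of that selection logic, assuming tensorrt_llm is installed and that model_path points at a local checkpoint; the helper name build_llm is illustrative and not part of the commit.

from typing import Optional


def build_llm(model_path: str, backend: Optional[str] = None, **kwargs):
    """Return an LLM, passing fast_build only for the non-pytorch (TRT) backend.

    Illustrative sketch of the harness logic; build_llm itself is hypothetical.
    """
    from tensorrt_llm import LLM  # engine-building (TRT) backend

    llm_cls = LLM
    extra_kwargs = {}
    if backend == "pytorch":
        # The PyTorch backend does not build a TRT engine, so fast_build
        # does not apply to it.
        from tensorrt_llm._torch import LLM as LLM_torch
        llm_cls = LLM_torch
    else:
        # fast_build only makes sense on the engine-building path.
        extra_kwargs["fast_build"] = True
    return llm_cls(model_path, **kwargs, **extra_kwargs)

Called as build_llm(llama_model_path, backend=backend, tensor_parallel_size=tp_size), this mirrors how the patched harness forwards llm_extra_kwargs into the LLM_CLASS(...) call.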