[None][fix] fix CUDA graph config for test_llm_api_pytorch.py. (#6826)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
commit cf00003f3d
parent 3d169bfdad
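This commit migrates the tests from the boolean use_cuda_graph flag to the cuda_graph_config argument of the LLM API: a CudaGraphConfig object enables CUDA graph capture, and None disables it. Below is a minimal sketch of the resulting call pattern; the import paths, model id, and prompt are illustrative assumptions, not taken from this diff.

    from tensorrt_llm import LLM
    from tensorrt_llm.llmapi import CudaGraphConfig  # assumed import path

    cuda_graph = True  # test parameter that toggles CUDA graph capture

    # Old style (removed by this commit): use_cuda_graph=cuda_graph
    # New style: pass a config object to enable, or None to disable.
    with LLM(model="meta-llama/Llama-4-Maverick-17B-128E-Instruct",  # placeholder id
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None) as llm:
        output = llm.generate("Hello, world!")
        print(output)

Switching from a flag to a config object lets CUDA graph options (batch sizes, padding, and similar knobs) live on CudaGraphConfig rather than accumulating as separate LLM constructor arguments.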
@@ -503,7 +503,8 @@ class TestLlama4MaverickInstruct(LlmapiAccuracyTestHarness):
                  max_seq_len=8192,
                  pipeline_parallel_size=pp_size,
                  moe_expert_parallel_size=ep_size,
-                 use_cuda_graph=cuda_graph) as llm:
+                 cuda_graph_config=CudaGraphConfig()
+                 if cuda_graph else None) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
             assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
             task = MMLU(self.MODEL_NAME)
@@ -526,7 +527,8 @@ class TestLlama4MaverickInstruct(LlmapiAccuracyTestHarness):
                  moe_expert_parallel_size=ep_size,
                  enable_chunked_prefill=True,
                  max_num_tokens=256,
-                 use_cuda_graph=cuda_graph) as llm:
+                 cuda_graph_config=CudaGraphConfig()
+                 if cuda_graph else None) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
             assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
             task = MMLU(self.MODEL_NAME)
@@ -646,7 +648,8 @@ class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
                  moe_expert_parallel_size=ep_size,
                  enable_chunked_prefill=True,
                  max_num_tokens=256,
-                 use_cuda_graph=cuda_graph) as llm:
+                 cuda_graph_config=CudaGraphConfig()
+                 if cuda_graph else None) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
             assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
             task = MMLU(self.MODEL_NAME)
@@ -668,7 +671,8 @@ class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
                  max_seq_len=22000,
                  enable_chunked_prefill=True,
                  max_num_tokens=256,
-                 use_cuda_graph=cuda_graph) as llm:
+                 cuda_graph_config=CudaGraphConfig()
+                 if cuda_graph else None) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
             assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
             task = MMLU(self.MODEL_NAME)