mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[nvbug/5341178][fix] Fix OOM in Llama 4 accuracy test (#5735)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
This commit is contained in:
parent
2aacdba1e4
commit
2b66fe8fbd
@@ -389,11 +389,14 @@ class TestLlama4MaverickInstruct(LlmapiAccuracyTestHarness):
|
||||
(8, 1, 8)],
|
||||
ids=["tp8", "tp8ep4", "tp8ep8"])
|
||||
def test_auto_dtype(self, cuda_graph, tp_size, pp_size, ep_size):
|
||||
with LLM(self.MODEL_PATH,
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
use_cuda_graph=cuda_graph) as llm:
|
||||
with LLM(
|
||||
self.MODEL_PATH,
|
||||
tensor_parallel_size=tp_size,
|
||||
# Keep this low to avoid warmup OOM in CI
|
||||
max_seq_len=8192,
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
use_cuda_graph=cuda_graph) as llm:
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
task = GSM8K(self.MODEL_NAME)
|
||||
@@ -411,11 +414,14 @@ class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
|
||||
(8, 1, 8)],
|
||||
ids=["tp8", "tp8ep4", "tp8ep8"])
|
||||
def test_auto_dtype(self, cuda_graph, tp_size, pp_size, ep_size):
|
||||
with LLM(self.MODEL_PATH,
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
use_cuda_graph=cuda_graph) as llm:
|
||||
with LLM(
|
||||
self.MODEL_PATH,
|
||||
tensor_parallel_size=tp_size,
|
||||
# Keep this low to avoid warmup OOM in CI
|
||||
max_seq_len=8192,
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
use_cuda_graph=cuda_graph) as llm:
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
task = GSM8K(self.MODEL_NAME)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user