mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[TRTLLM-7245][feat] add test_multi_nodes_eval tests (#7108)
Signed-off-by: Xin He (SW-GPU) <200704525+xinhe-nv@users.noreply.github.com>
This commit is contained in:
parent
d94cc3fa3c
commit
b8b2bd4a0a
@ -956,3 +956,23 @@ def get_dummy_spec_decoding_heads(hf_model_dir,
|
||||
export_hf_checkpoint(model,
|
||||
dtype=model.config.torch_dtype,
|
||||
export_dir=os.path.join(save_dir, 'fp8'))
|
||||
|
||||
|
||||
def get_mmlu_accuracy(output):
|
||||
mmlu_line = None
|
||||
for line in output.split('\n'):
|
||||
if "MMLU weighted average accuracy:" in line:
|
||||
mmlu_line = line
|
||||
break
|
||||
|
||||
if mmlu_line is None:
|
||||
raise Exception(
|
||||
f"Could not find 'MMLU weighted average accuracy:' in output. Full output:\n{output}"
|
||||
)
|
||||
|
||||
mmlu_accuracy = float(
|
||||
mmlu_line.split("MMLU weighted average accuracy: ")[1].split(" (")[0])
|
||||
|
||||
print(f"MMLU weighted average accuracy is: {mmlu_accuracy}")
|
||||
|
||||
return mmlu_accuracy
|
||||
|
||||
@ -28,8 +28,9 @@ from defs.common import convert_weights
|
||||
from defs.trt_test_alternative import (check_call, check_call_negative_test,
|
||||
check_output)
|
||||
|
||||
from .common import (PluginOptions, convert_weights, prune_checkpoint,
|
||||
quantize_data, refit_model, venv_check_call)
|
||||
from .common import (PluginOptions, convert_weights, get_mmlu_accuracy,
|
||||
prune_checkpoint, quantize_data, refit_model,
|
||||
venv_check_call)
|
||||
from .conftest import (llm_models_root, skip_no_sm120, skip_nvlink_inactive,
|
||||
skip_post_blackwell, skip_pre_blackwell, skip_pre_hopper,
|
||||
tests_path, unittest_path)
|
||||
@ -42,6 +43,7 @@ if TEST_MEM_USAGE:
|
||||
os.environ['TLLM_LOG_LEVEL'] = 'INFO'
|
||||
|
||||
_MEM_FRACTION_50 = 0.5
|
||||
_MEM_FRACTION_80 = 0.8
|
||||
_MEM_FRACTION_95 = 0.95
|
||||
|
||||
|
||||
@ -2677,4 +2679,43 @@ def test_ptp_quickstart_advanced_llama_multi_nodes(llm_root, llm_venv,
|
||||
check_call(" ".join(run_cmd), shell=True, env=llm_venv._new_env)
|
||||
|
||||
|
||||
# End of Pivot-To-Python examples
|
||||
@pytest.mark.timeout(5400)
|
||||
@pytest.mark.skip_less_device_memory(80000)
|
||||
@pytest.mark.skip_less_device(4)
|
||||
@pytest.mark.parametrize("eval_task", ["mmlu"])
|
||||
@pytest.mark.parametrize("tp_size,pp_size,ep_size", [(16, 1, 8), (8, 2, 8)],
|
||||
ids=["tp16", "tp8pp2"])
|
||||
@pytest.mark.parametrize("model_path", [
|
||||
pytest.param('llama-3.3-models/Llama-3.3-70B-Instruct',
|
||||
marks=skip_pre_hopper),
|
||||
pytest.param('llama4-models/Llama-4-Maverick-17B-128E-Instruct',
|
||||
marks=skip_pre_hopper),
|
||||
pytest.param('llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8',
|
||||
marks=skip_pre_hopper),
|
||||
pytest.param('Qwen3/Qwen3-235B-A22B', marks=skip_pre_hopper),
|
||||
pytest.param('Qwen3/saved_models_Qwen3-235B-A22B_nvfp4_hf',
|
||||
marks=skip_pre_blackwell),
|
||||
pytest.param('DeepSeek-R1/DeepSeek-R1-0528-FP4', marks=skip_pre_blackwell),
|
||||
])
|
||||
def test_multi_nodes_eval(llm_venv, model_path, tp_size, pp_size, ep_size,
|
||||
eval_task):
|
||||
if "Llama-4" in model_path and tp_size == 16:
|
||||
pytest.skip("Llama-4 with tp16 is not supported")
|
||||
|
||||
mmlu_threshold = 81.5
|
||||
run_cmd = [
|
||||
"trtllm-llmapi-launch",
|
||||
"trtllm-eval",
|
||||
f"--model={llm_models_root()}/{model_path}",
|
||||
f"--ep_size={ep_size}",
|
||||
f"--tp_size={tp_size}",
|
||||
f"--pp_size={pp_size}",
|
||||
f"--kv_cache_free_gpu_memory_fraction={_MEM_FRACTION_80}",
|
||||
"--max_batch_size=32",
|
||||
eval_task,
|
||||
]
|
||||
output = check_output(" ".join(run_cmd), shell=True, env=llm_venv._new_env)
|
||||
|
||||
if os.environ.get("SLURM_PROCID", '0') == '0':
|
||||
mmlu_accuracy = get_mmlu_accuracy(output)
|
||||
assert mmlu_accuracy > mmlu_threshold, f"MMLU accuracy {mmlu_accuracy} is less than threshold {mmlu_threshold}"
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
examples/test_llama.py::test_llm_llama_v3_1_2nodes_8gpus[llama-3.1-8b-disable_fp8-tp16pp1-build]
|
||||
examples/test_llama.py::test_llm_llama_v3_1_2nodes_8gpus[llama-3.1-8b-disable_fp8-tp16pp1-infer]
|
||||
examples/test_mixtral.py::test_llm_mixtral_2nodes_8gpus[Mixtral-8x22B-v0.1-plugin-renormalize-tensor_parallel-build]
|
||||
examples/test_mixtral.py::test_llm_mixtral_2nodes_8gpus[Mixtral-8x22B-v0.1-plugin-renormalize-tensor_parallel-infer]
|
||||
test_e2e.py::test_ptp_quickstart_advanced_deepseek_multi_nodes[DeepSeek-V3]
|
||||
test_e2e.py::test_ptp_quickstart_advanced_deepseek_multi_nodes[DeepSeek-R1/DeepSeek-R1-0528-FP4]
|
||||
test_e2e.py::test_ptp_quickstart_advanced_llama_multi_nodes[llama-3.3-models/Llama-3.3-70B-Instruct]
|
||||
test_e2e.py::test_ptp_quickstart_advanced_llama_multi_nodes[llama4-models/Llama-4-Maverick-17B-128E-Instruct]
|
||||
test_e2e.py::test_openai_multinodes_chat_tp16pp1
|
||||
test_e2e.py::test_multi_nodes_eval[llama-3.3-models/Llama-3.3-70B-Instruct-tp16-mmlu]
|
||||
test_e2e.py::test_multi_nodes_eval[llama4-models/Llama-4-Maverick-17B-128E-Instruct-tp8pp2-mmlu]
|
||||
test_e2e.py::test_multi_nodes_eval[llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-tp8pp2-mmlu]
|
||||
test_e2e.py::test_multi_nodes_eval[Qwen3/Qwen3-235B-A22B-tp16-mmlu]
|
||||
test_e2e.py::test_multi_nodes_eval[Qwen3/saved_models_Qwen3-235B-A22B_nvfp4_hf-tp16-mmlu]
|
||||
test_e2e.py::test_multi_nodes_eval[DeepSeek-R1/DeepSeek-R1-0528-FP4-tp16-mmlu]
|
||||
|
||||
@ -324,3 +324,4 @@ accuracy/test_disaggregated_serving.py::TestQwen3_30B_A3B::test_mixed_ctx_gen_mo
|
||||
full:L40S/accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=False] SKIP (https://nvbugs/5347051)
|
||||
full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] SKIP (https://nvbugs/5471106)
|
||||
full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp2] SKIP (https://nvbugs/5471108)
|
||||
test_e2e.py::test_multi_nodes_eval[llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-tp8pp2-mmlu] SKIP (https://nvbugs/5473781)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user