From 9c1b75e978e5b603208f447973ef2563a9f66961 Mon Sep 17 00:00:00 2001
From: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
Date: Mon, 22 Sep 2025 15:12:43 +0800
Subject: [PATCH] [TRTLLM-7070][feat] add gpt-oss chunked prefill tests
 (#7779)

Signed-off-by: Xin He (SW-GPU) <200704525+xinhe-nv@users.noreply.github.com>
---
 .../defs/accuracy/test_llm_api_pytorch.py    | 36 +++++++++++++++++++
 .../test_lists/qa/llm_function_core.txt      |  4 +++
 .../qa/llm_function_core_sanity.txt          |  4 +++
 .../test_lists/qa/llm_function_nim.txt       |  4 +++
 .../apps/_test_trtllm_serve_benchmark.py     | 16 +++++++--
 5 files changed, 62 insertions(+), 2 deletions(-)

diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index ef02fcd9c6..2e295473bf 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -3261,6 +3261,42 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
         task.evaluate(llm,
                       extra_evaluator_kwargs=self.extra_evaluator_kwargs)
 
+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
+    @pytest.mark.parametrize(
+        "moe_backend",
+        ["CUTLASS",
+         pytest.param("TRTLLM", marks=skip_pre_blackwell), "TRITON"],
+        ids=["cutlass", "trtllm", "triton"])
+    def test_w4_chunked_prefill(self, kv_cache_dtype, moe_backend, mocker):
+        if moe_backend == "TRITON":
+            if not IS_TRITON_KERNELS_AVAILABLE:
+                pytest.skip("Triton kernels are not available")
+
+        pytorch_config = dict(disable_overlap_scheduler=True,
+                              cuda_graph_config=CudaGraphConfig())
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
+                                        dtype=kv_cache_dtype)
+
+        model_name = "GPT-OSS/MXFP4"
+        with LLM(self.MODEL_PATH,
+                 tensor_parallel_size=4,
+                 pipeline_parallel_size=1,
+                 moe_expert_parallel_size=1,
+                 kv_cache_config=kv_cache_config,
+                 max_seq_len=8192,
+                 max_num_tokens=512,
+                 enable_chunked_prefill=True,
+                 enable_attention_dp=False,
+                 moe_config=MoeConfig(backend=moe_backend),
+                 **pytorch_config) as llm:
+            mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
+            mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
+                              {"scores_filter": "exact_match,flexible-extract"})
+            task = GSM8K(model_name)
+            task.evaluate(llm,
+                          extra_evaluator_kwargs=self.extra_evaluator_kwargs)
+
 
 class TestEXAONE4(LlmapiAccuracyTestHarness):
     MODEL_NAME = "LGAI-EXAONE/EXAONE-4.0-32B"
diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt
index e445478bbf..6c3f9200cd 100644
--- a/tests/integration/test_lists/qa/llm_function_core.txt
+++ b/tests/integration/test_lists/qa/llm_function_core.txt
@@ -547,6 +547,10 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-fp8]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-fp8]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-auto]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-auto]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[triton-auto]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-fp8]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
diff --git a/tests/integration/test_lists/qa/llm_function_core_sanity.txt b/tests/integration/test_lists/qa/llm_function_core_sanity.txt
index c398961276..b04dc167f0 100644
--- a/tests/integration/test_lists/qa/llm_function_core_sanity.txt
+++ b/tests/integration/test_lists/qa/llm_function_core_sanity.txt
@@ -83,6 +83,10 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-fp8]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-fp8]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-auto]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-auto]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[triton-auto]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-fp8]
 accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestKanana_Instruct::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestKimiK2::test_fp8_blockscale[latency]
diff --git a/tests/integration/test_lists/qa/llm_function_nim.txt b/tests/integration/test_lists/qa/llm_function_nim.txt
index a969bd3c08..db14f531f1 100644
--- a/tests/integration/test_lists/qa/llm_function_nim.txt
+++ b/tests/integration/test_lists/qa/llm_function_nim.txt
@@ -153,6 +153,10 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-fp8]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-auto]
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-fp8]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-auto]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-auto]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[triton-auto]
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-fp8]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3NemotronSuper49Bv1::test_fp8_prequantized_tp2
 accuracy/test_llm_api_pytorch.py::TestLlama3_1NemotronNano8Bv1::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestLlama3_1NemotronNano8Bv1::test_fp8_prequantized
diff --git a/tests/unittest/llmapi/apps/_test_trtllm_serve_benchmark.py b/tests/unittest/llmapi/apps/_test_trtllm_serve_benchmark.py
index 908757b7c9..db7c6dfc4f 100644
--- a/tests/unittest/llmapi/apps/_test_trtllm_serve_benchmark.py
+++ b/tests/unittest/llmapi/apps/_test_trtllm_serve_benchmark.py
@@ -56,8 +56,20 @@ def test_trtllm_serve_benchmark(server: RemoteOpenAIServer, benchmark_root: str,
     client_script = os.path.join(benchmark_root, "benchmark_serving.py")
     dataset = dataset_path("sharegpt")
     benchmark_cmd = [
-        "python3", client_script, "--dataset-name", "sharegpt", "--model",
-        model_name, "--dataset-path", dataset, "--tokenizer", model_path
+        "python3",
+        client_script,
+        "--dataset-name",
+        "sharegpt",
+        "--model",
+        model_name,
+        "--dataset-path",
+        dataset,
+        "--tokenizer",
+        model_path,
+        "--temperature",
+        "1.0",
+        "--top-p",
+        "1.0",
     ]
     # CalledProcessError will be raised if any errors occur