diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py
index a98bdfab34..f05e327c9e 100644
--- a/tests/integration/defs/accuracy/test_disaggregated_serving.py
+++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -10,7 +10,7 @@ import tempfile
 import time
 from collections import namedtuple
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 import openai
 import pytest
@@ -26,8 +26,8 @@ from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer
 from ..conftest import (get_device_count, llm_models_root,
                         parametrize_with_ids, skip_no_hopper,
                         skip_pre_blackwell, skip_pre_hopper)
 from ..trt_test_alternative import popen
-from .accuracy_core import (GSM8K, MMLU, JsonModeEval,
-                            LlmapiAccuracyTestHarness, get_accuracy_task)
+from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness,
+                            get_accuracy_task)
 
 
 class Result(GenerationResultBase):
@@ -48,8 +48,12 @@ class Result(GenerationResultBase):
 
 DuckLLM = namedtuple('DuckLLM', ['args', 'tokenizer', 'generate_async'])
 
+# Timeout for the entire test
 DEFAULT_TEST_TIMEOUT = 3600
-DEFAULT_SERVER_WAITING_TIMEOUT = 3600
+# Timeout for waiting for the servers to become ready
+DEFAULT_SERVER_WAITING_TIMEOUT = 2100
+# Timeout for the accuracy evaluation
+DEFAULT_ACC_EVALUATION_TIMEOUT = 1500
 
 
 @functools.lru_cache(maxsize=1)
@@ -104,6 +108,31 @@ class MyThreadPoolExecutor(ThreadPoolExecutor):
         return False
 
 
+def run_accuracy_test(llm: "DuckLLM",
+                      model_name: str,
+                      test_sets: List[Union[str, type]] = ["MMLU", "GSM8K"],
+                      extra_evaluator_kwargs: Optional[Dict[Union[str, type],
+                                                            Dict[str,
+                                                                 Any]]] = None,
+                      timeout: int = DEFAULT_ACC_EVALUATION_TIMEOUT):
+    start_time = time.time()
+    for test_set in test_sets:
+        # Look up per-task kwargs by the caller-provided key (name or class).
+        if extra_evaluator_kwargs is not None:
+            kwargs = extra_evaluator_kwargs.get(test_set, {})
+        else:
+            kwargs = {}
+        if isinstance(test_set, str):
+            test_set = get_accuracy_task(test_set)
+        task = test_set(model_name)
+        task.evaluate(llm, extra_evaluator_kwargs=kwargs)
+    elapsed_time = time.time() - start_time
+    if elapsed_time > timeout:
+        pytest.fail(
+            f"The accuracy evaluation took too long to complete. Limit: {timeout}s, Actual: {elapsed_time:.2f}s"
+        )
+
+
 @contextlib.contextmanager
 def launch_disaggregated_llm(
         disaggregated_server_config: Dict[str, Any],
@@ -301,6 +330,7 @@ def launch_disaggregated_llm(
             server_processes,
     ):
         start_time = time.time()
+        server_is_ready = False
        while time.time() - start_time < server_waiting_timeout:
             time.sleep(5)
             for process in itertools.chain(ctx_processes, gen_processes,
@@ -313,9 +343,14 @@ def launch_disaggregated_llm(
                 print("Checking health endpoint")
                 response = requests.get(f"http://localhost:{serve_port}/health")
                 if response.status_code == 200:
+                    server_is_ready = True
                     break
             except requests.exceptions.ConnectionError:
                 continue
+        if not server_is_ready:
+            pytest.fail(
+                f"Server is not ready after {server_waiting_timeout} seconds. Please check the logs for more details."
+            )
 
         client = openai.OpenAI(api_key="1234567890",
                                base_url=f"http://localhost:{serve_port}/v1",
@@ -479,9 +514,7 @@ def run_parallel_test(model_name: str,
                                   model_path,
                                   ctx_model=ctx_model,
                                   gen_model=gen_model) as llm:
-        for test_set in test_sets:
-            task = test_set(model_name)
-            task.evaluate(llm)
+        run_accuracy_test(llm, model_name, test_sets)
 
 
 @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
@@ -526,10 +559,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
     @pytest.mark.skip_less_device(2)
     def test_ngram(self):
@@ -576,8 +606,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
 
     @pytest.mark.skip_less_device(2)
     @skip_pre_hopper
@@ -636,8 +665,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
 
     @pytest.mark.skip_less_device(2)
     @pytest.mark.skip_less_device_memory(32000)
@@ -673,8 +701,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = JsonModeEval(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["JsonModeEval"])
 
     @pytest.mark.skip_less_device(2)
     @pytest.mark.skip_less_device_memory(48000)
@@ -730,8 +757,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = JsonModeEval(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["JsonModeEval"])
 
     @pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
                              ids=["tp1pp2", "tp2pp1", "tp2pp2"])
@@ -792,10 +818,7 @@ class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
                                       gen_server_config,
                                       self.MODEL_PATH,
                                       tensor_parallel_size=4) as llm:
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
 
 @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
@@ -836,10 +859,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
     @pytest.mark.skip_less_device(8)
     @parametrize_with_ids("overlap_scheduler", [True, False])
@@ -877,10 +897,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
                                       gen_server_config,
                                       self.MODEL_PATH,
                                       tensor_parallel_size=4) as llm:
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
     @skip_pre_blackwell
     @pytest.mark.skip_less_device(8)
@@ -961,10 +978,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
     @pytest.mark.skip_less_device(2)
     @pytest.mark.skip_less_device_memory(60000)
@@ -1017,8 +1031,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = JsonModeEval(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["JsonModeEval"])
 
 
 @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
@@ -1071,10 +1084,7 @@ class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
 
 @skip_pre_blackwell
@@ -1136,9 +1146,11 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
             model_name = "GPT-OSS/120B-MXFP4"
-            task = GSM8K(model_name)
-            task.evaluate(llm,
-                          extra_evaluator_kwargs=self.extra_evaluator_kwargs)
+            run_accuracy_test(
+                llm,
+                model_name,
+                test_sets=["GSM8K"],
+                extra_evaluator_kwargs={"GSM8K": self.extra_evaluator_kwargs})
 
 
 @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
@@ -1205,10 +1217,7 @@ class TestDeepSeekV32Exp(LlmapiAccuracyTestHarness):
                                       gen_server_config,
                                       self.MODEL_PATH,
                                       max_workers=128) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
 
 @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
@@ -1247,8 +1256,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
 
     @skip_pre_hopper
     @pytest.mark.skip_less_device(2)
@@ -1284,10 +1292,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
     def test_chunked_prefill(self):
         # bs=1 will stabilize the result, but the test will be much slower
@@ -1325,10 +1330,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
 
 @skip_pre_blackwell
@@ -1364,7 +1366,6 @@ class TestKimiK2(LlmapiAccuracyTestHarness):
     @pytest.mark.skip_less_device(8)
     @pytest.mark.skip_less_device_memory(200000)
     def test_nvfp4(self):
-        model_name = "moonshotai/Kimi-K2-Thinking"
         model_path = f"{llm_models_root()}/Kimi-K2-Thinking-NVFP4"
         ctx_server_config = {
             "max_batch_size": 16,
@@ -1408,5 +1409,4 @@ class TestKimiK2(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       model_path) as llm:
-            task = GSM8K(model_name)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
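
A minimal usage sketch of the new run_accuracy_test helper (illustrative only, not part of the patch): the llm handle and the server configs come from launch_disaggregated_llm exactly as in the hunks above, and the {"apply_chat_template": True} value is a hypothetical stand-in for whatever per-task evaluator kwargs a test actually needs, keyed by the same name used in test_sets.

    # Inside a test method of an LlmapiAccuracyTestHarness subclass:
    with launch_disaggregated_llm(disaggregated_server_config,
                                  ctx_server_config, gen_server_config,
                                  self.MODEL_PATH) as llm:
        # Runs MMLU with default evaluator kwargs and GSM8K with the per-task
        # kwargs; the test fails if the whole evaluation exceeds
        # DEFAULT_ACC_EVALUATION_TIMEOUT (1500 s by default).
        run_accuracy_test(
            llm,
            self.MODEL_NAME,
            test_sets=["MMLU", "GSM8K"],
            extra_evaluator_kwargs={"GSM8K": {"apply_chat_template": True}})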