[https://nvbugs/5768068][chore] improve disagg acc tests (#10833)

Signed-off-by: Bo Deng <deemod@nvidia.com>
Bo Deng 2026-01-22 22:45:35 +08:00 committed by GitHub
parent 5e34112b27
commit a218cf02fd


@@ -10,7 +10,7 @@ import tempfile
 import time
 from collections import namedtuple
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 import openai
 import pytest
@@ -26,8 +26,8 @@ from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer
 from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
                         skip_no_hopper, skip_pre_blackwell, skip_pre_hopper)
 from ..trt_test_alternative import popen
-from .accuracy_core import (GSM8K, MMLU, JsonModeEval,
-                            LlmapiAccuracyTestHarness, get_accuracy_task)
+from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness,
+                            get_accuracy_task)


 class Result(GenerationResultBase):
@@ -48,8 +48,12 @@ class Result(GenerationResultBase):
 DuckLLM = namedtuple('DuckLLM', ['args', 'tokenizer', 'generate_async'])


+# Timeout for the entire test
 DEFAULT_TEST_TIMEOUT = 3600
-DEFAULT_SERVER_WAITING_TIMEOUT = 3600
+# Timeout for the server waiting
+DEFAULT_SERVER_WAITING_TIMEOUT = 2100
+# Timeout for the accuracy evaluation
+DEFAULT_ACC_EVALUATION_TIMEOUT = 1500


 @functools.lru_cache(maxsize=1)
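
Presumably the two new budgets are sized to fit inside the overall test budget: 2100 s of server waiting plus 1500 s of accuracy evaluation sums exactly to the 3600 s DEFAULT_TEST_TIMEOUT.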
@@ -104,6 +108,31 @@ class MyThreadPoolExecutor(ThreadPoolExecutor):
         return False


+def run_accuracy_test(llm: "DuckLLM",
+                      model_name: str,
+                      test_sets: List[Union[str, type]] = ["MMLU", "GSM8K"],
+                      extra_evaluator_kwargs: Optional[Dict[Union[str, type],
+                                                            Dict[str,
+                                                                 Any]]] = None,
+                      timeout: int = DEFAULT_ACC_EVALUATION_TIMEOUT):
+    start_time = time.time()
+    for test_set in test_sets:
+        # Resolve per-task kwargs before a string test set is mapped to its
+        # task class, so that both string and class keys match.
+        if extra_evaluator_kwargs is not None:
+            kwargs = extra_evaluator_kwargs.get(test_set, {})
+        else:
+            kwargs = {}
+        if isinstance(test_set, str):
+            test_set = get_accuracy_task(test_set)
+        task = test_set(model_name)
+        task.evaluate(llm, extra_evaluator_kwargs=kwargs)
+    elapsed_time = time.time() - start_time
+    if elapsed_time > timeout:
+        pytest.fail(
+            f"The accuracy evaluation took too long to complete. Expected: {timeout}s, Actual: {elapsed_time:.2f}s"
+        )
+
+
 @contextlib.contextmanager
 def launch_disaggregated_llm(
         disaggregated_server_config: Dict[str, Any],
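
For context, a minimal sketch of how call sites change, assuming an `llm` handle yielded by launch_disaggregated_llm (all names are taken from this diff):

    # Before: each test spelled the evaluation loop out by hand.
    task = MMLU(self.MODEL_NAME)
    task.evaluate(llm)
    task = GSM8K(self.MODEL_NAME)
    task.evaluate(llm)

    # After: one call with a shared wall-clock budget; string entries are
    # resolved through get_accuracy_task, task classes pass through as-is.
    run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])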
@@ -301,6 +330,7 @@ def launch_disaggregated_llm(
             server_processes,
     ):
         start_time = time.time()
+        server_is_ready = False
         while time.time() - start_time < server_waiting_timeout:
             time.sleep(5)
             for process in itertools.chain(ctx_processes, gen_processes,
@@ -313,9 +343,14 @@ def launch_disaggregated_llm(
                 print("Checking health endpoint")
                 response = requests.get(f"http://localhost:{serve_port}/health")
                 if response.status_code == 200:
+                    server_is_ready = True
                     break
             except requests.exceptions.ConnectionError:
                 continue
+        if not server_is_ready:
+            pytest.fail(
+                f"Server is not ready after {server_waiting_timeout} seconds. Please check the logs for more details."
+            )

         client = openai.OpenAI(api_key="1234567890",
                                base_url=f"http://localhost:{serve_port}/v1",
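
The server_is_ready flag turns a quiet timeout into an explicit pytest failure. A standalone sketch of the same poll-until-healthy pattern (wait_for_health is a hypothetical name; the 5-second interval and health endpoint mirror the diff):

    import time

    import pytest
    import requests


    def wait_for_health(url: str, timeout: float, interval: float = 5.0) -> None:
        """Poll a health endpoint until it returns 200 or the budget expires."""
        deadline = time.time() + timeout
        while time.time() < deadline:
            time.sleep(interval)
            try:
                if requests.get(url).status_code == 200:
                    return  # server became ready in time
            except requests.exceptions.ConnectionError:
                continue  # server not accepting connections yet
        pytest.fail(f"Server is not ready after {timeout} seconds.")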
@@ -479,9 +514,7 @@ def run_parallel_test(model_name: str,
                                   model_path,
                                   ctx_model=ctx_model,
                                   gen_model=gen_model) as llm:
-        for test_set in test_sets:
-            task = test_set(model_name)
-            task.evaluate(llm)
+        run_accuracy_test(llm, model_name, test_sets)


 @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
@@ -526,10 +559,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])

     @pytest.mark.skip_less_device(2)
     def test_ngram(self):
@@ -576,8 +606,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])

     @pytest.mark.skip_less_device(2)
     @skip_pre_hopper
@@ -636,8 +665,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])

     @pytest.mark.skip_less_device(2)
     @pytest.mark.skip_less_device_memory(32000)
@@ -673,8 +701,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = JsonModeEval(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["JsonModeEval"])

     @pytest.mark.skip_less_device(2)
     @pytest.mark.skip_less_device_memory(48000)
@@ -730,8 +757,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = JsonModeEval(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["JsonModeEval"])

     @pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
                              ids=["tp1pp2", "tp2pp1", "tp2pp2"])
@@ -792,10 +818,7 @@ class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
                                       gen_server_config,
                                       self.MODEL_PATH,
                                       tensor_parallel_size=4) as llm:
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])


 @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
@@ -836,10 +859,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])

     @pytest.mark.skip_less_device(8)
     @parametrize_with_ids("overlap_scheduler", [True, False])
@@ -877,10 +897,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
                                       gen_server_config,
                                       self.MODEL_PATH,
                                       tensor_parallel_size=4) as llm:
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])

     @skip_pre_blackwell
     @pytest.mark.skip_less_device(8)
@@ -961,10 +978,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])

     @pytest.mark.skip_less_device(2)
     @pytest.mark.skip_less_device_memory(60000)
@@ -1017,8 +1031,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = JsonModeEval(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["JsonModeEval"])


 @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
@@ -1071,10 +1084,7 @@ class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])


 @skip_pre_blackwell
@@ -1136,9 +1146,11 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
             model_name = "GPT-OSS/120B-MXFP4"
-            task = GSM8K(model_name)
-            task.evaluate(llm,
-                          extra_evaluator_kwargs=self.extra_evaluator_kwargs)
+            run_accuracy_test(
+                llm,
+                model_name,
+                test_sets=["GSM8K"],
+                extra_evaluator_kwargs={"GSM8K": self.extra_evaluator_kwargs})


 @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
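
The GPT-OSS test is the one call site that exercises the per-task kwargs mapping: keys name a test set, values are forwarded to that task's evaluate(). A hedged illustration (the chat-template kwarg is a hypothetical example, not taken from this diff):

    run_accuracy_test(
        llm,
        "GPT-OSS/120B-MXFP4",
        test_sets=["GSM8K", "MMLU"],
        extra_evaluator_kwargs={"GSM8K": {"apply_chat_template": True}})
    # MMLU falls back to {} because the mapping has no "MMLU" entry.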
@@ -1205,10 +1217,7 @@ class TestDeepSeekV32Exp(LlmapiAccuracyTestHarness):
                                       gen_server_config,
                                       self.MODEL_PATH,
                                       max_workers=128) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])


 @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
@@ -1247,8 +1256,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])

     @skip_pre_hopper
     @pytest.mark.skip_less_device(2)
@@ -1284,10 +1292,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])

     def test_chunked_prefill(self):
         # bs=1 will stabilize the result, but the test will be much slower
@@ -1325,10 +1330,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])


 @skip_pre_blackwell
@@ -1364,7 +1366,6 @@ class TestKimiK2(LlmapiAccuracyTestHarness):
     @pytest.mark.skip_less_device(8)
     @pytest.mark.skip_less_device_memory(200000)
     def test_nvfp4(self):
-        model_name = "moonshotai/Kimi-K2-Thinking"
         model_path = f"{llm_models_root()}/Kimi-K2-Thinking-NVFP4"
         ctx_server_config = {
             "max_batch_size": 16,
@@ -1408,5 +1409,4 @@ class TestKimiK2(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       model_path) as llm:
-            task = GSM8K(model_name)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])