mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-04 18:21:52 +08:00
[https://nvbugs/5768068][chore] improve disagg acc tests (#10833)
Signed-off-by: Bo Deng <deemod@nvidia.com>
parent 5e34112b27
commit a218cf02fd
@@ -10,7 +10,7 @@ import tempfile
 import time
 from collections import namedtuple
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 import openai
 import pytest
@@ -26,8 +26,8 @@ from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer
 from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
                         skip_no_hopper, skip_pre_blackwell, skip_pre_hopper)
 from ..trt_test_alternative import popen
-from .accuracy_core import (GSM8K, MMLU, JsonModeEval,
-                            LlmapiAccuracyTestHarness, get_accuracy_task)
+from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness,
+                            get_accuracy_task)
 
 
 class Result(GenerationResultBase):
@@ -48,8 +48,12 @@ class Result(GenerationResultBase):
 
 DuckLLM = namedtuple('DuckLLM', ['args', 'tokenizer', 'generate_async'])
 
+# Timeout for the entire test
 DEFAULT_TEST_TIMEOUT = 3600
-DEFAULT_SERVER_WAITING_TIMEOUT = 3600
+# Timeout for the server waiting
+DEFAULT_SERVER_WAITING_TIMEOUT = 2100
+# Timeout for the accuracy evaluation
+DEFAULT_ACC_EVALUATION_TIMEOUT = 1500
 
 
 @functools.lru_cache(maxsize=1)
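Note on the revised budgets: DEFAULT_SERVER_WAITING_TIMEOUT + DEFAULT_ACC_EVALUATION_TIMEOUT = 2100 s + 1500 s = 3600 s = DEFAULT_TEST_TIMEOUT, so server startup and accuracy evaluation now appear to partition the overall per-test timeout rather than each being allowed the full hour.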
@@ -104,6 +108,31 @@ class MyThreadPoolExecutor(ThreadPoolExecutor):
         return False
 
 
+def run_accuracy_test(llm: "DuckLLM",
+                      model_name: str,
+                      test_sets: List[Union[str, type]] = ["MMLU", "GSM8K"],
+                      extra_evaluator_kwargs: Optional[Dict[Union[str, type],
+                                                            Dict[str,
+                                                                 Any]]] = None,
+                      timeout: int = DEFAULT_ACC_EVALUATION_TIMEOUT):
+    start_time = time.time()
+    for test_set in test_sets:
+        if isinstance(test_set, str):
+            test_set = get_accuracy_task(test_set)
+        task = test_set(model_name)
+
+        if extra_evaluator_kwargs is not None:
+            kwargs = extra_evaluator_kwargs.get(test_set, {})
+        else:
+            kwargs = {}
+        task.evaluate(llm, extra_evaluator_kwargs=kwargs)
+    elapsed_time = time.time() - start_time
+    if elapsed_time > timeout:
+        pytest.fail(
+            f"The accuracy evaluation took too long to complete. Expected: {timeout}s, Actual: {elapsed_time:.2f}s"
+        )
+
+
 @contextlib.contextmanager
 def launch_disaggregated_llm(
         disaggregated_server_config: Dict[str, Any],
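For orientation, the helper added above replaces the per-test-set boilerplate that the later hunks delete. A minimal usage sketch, with the model name purely illustrative and the surrounding names taken from this file:

    with launch_disaggregated_llm(disaggregated_server_config,
                                  ctx_server_config, gen_server_config,
                                  model_path) as llm:
        # Test sets can be passed as registry names (resolved through
        # get_accuracy_task) or directly as task classes such as GSM8K.
        run_accuracy_test(llm,
                          "meta-llama/Llama-3.1-8B-Instruct",
                          test_sets=["MMLU", "GSM8K"])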
@@ -301,6 +330,7 @@ def launch_disaggregated_llm(
             server_processes,
     ):
         start_time = time.time()
+        server_is_ready = False
         while time.time() - start_time < server_waiting_timeout:
             time.sleep(5)
             for process in itertools.chain(ctx_processes, gen_processes,
@@ -313,9 +343,14 @@
                 print("Checking health endpoint")
                 response = requests.get(f"http://localhost:{serve_port}/health")
                 if response.status_code == 200:
+                    server_is_ready = True
                     break
             except requests.exceptions.ConnectionError:
                 continue
+        if not server_is_ready:
+            pytest.fail(
+                f"Server is not ready after {server_waiting_timeout} seconds. Please check the logs for more details."
+            )
 
         client = openai.OpenAI(api_key="1234567890",
                                base_url=f"http://localhost:{serve_port}/v1",
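The readiness check above polls the proxy's /health endpoint every five seconds until it answers HTTP 200 or the waiting budget runs out; the new server_is_ready flag turns a silent timeout into an explicit pytest failure. The same pattern as a self-contained sketch (port and timeout values are illustrative):

    import time

    import requests

    def wait_for_health(port: int, timeout: float = 2100.0) -> bool:
        # Poll /health every 5 s until it returns HTTP 200 or the timeout
        # elapses; ConnectionError means the server is not up yet.
        deadline = time.time() + timeout
        while time.time() < deadline:
            time.sleep(5)
            try:
                if requests.get(
                        f"http://localhost:{port}/health").status_code == 200:
                    return True
            except requests.exceptions.ConnectionError:
                continue
        return False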
@@ -479,9 +514,7 @@ def run_parallel_test(model_name: str,
                                    model_path,
                                    ctx_model=ctx_model,
                                    gen_model=gen_model) as llm:
-        for test_set in test_sets:
-            task = test_set(model_name)
-            task.evaluate(llm)
+        run_accuracy_test(llm, model_name, test_sets)
 
 
 @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
@@ -526,10 +559,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
     @pytest.mark.skip_less_device(2)
     def test_ngram(self):
@@ -576,8 +606,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
 
     @pytest.mark.skip_less_device(2)
     @skip_pre_hopper
@@ -636,8 +665,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
 
     @pytest.mark.skip_less_device(2)
     @pytest.mark.skip_less_device_memory(32000)
@@ -673,8 +701,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = JsonModeEval(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["JsonModeEval"])
 
     @pytest.mark.skip_less_device(2)
     @pytest.mark.skip_less_device_memory(48000)
@@ -730,8 +757,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = JsonModeEval(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["JsonModeEval"])
 
     @pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
                              ids=["tp1pp2", "tp2pp1", "tp2pp2"])
@@ -792,10 +818,7 @@ class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
                                       gen_server_config,
                                       self.MODEL_PATH,
                                       tensor_parallel_size=4) as llm:
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
 
 @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
@@ -836,10 +859,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
     @pytest.mark.skip_less_device(8)
     @parametrize_with_ids("overlap_scheduler", [True, False])
@@ -877,10 +897,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
                                       gen_server_config,
                                       self.MODEL_PATH,
                                       tensor_parallel_size=4) as llm:
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
     @skip_pre_blackwell
     @pytest.mark.skip_less_device(8)
@@ -961,10 +978,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
     @pytest.mark.skip_less_device(2)
     @pytest.mark.skip_less_device_memory(60000)
@@ -1017,8 +1031,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = JsonModeEval(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["JsonModeEval"])
 
 
 @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
@@ -1071,10 +1084,7 @@ class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
 
 @skip_pre_blackwell
@@ -1136,9 +1146,11 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
             model_name = "GPT-OSS/120B-MXFP4"
-            task = GSM8K(model_name)
-            task.evaluate(llm,
-                          extra_evaluator_kwargs=self.extra_evaluator_kwargs)
+            run_accuracy_test(
+                llm,
+                model_name,
+                test_sets=["GSM8K"],
+                extra_evaluator_kwargs={"GSM8K": self.extra_evaluator_kwargs})
 
 
 @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
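The GPT-OSS hunk above is the one call site that needs per-test-set evaluator options: extra_evaluator_kwargs is now a mapping from test set to its kwargs rather than a single dict applied to every task. A sketch mirroring that call, with the MMLU entry added here purely for illustration:

    run_accuracy_test(
        llm,
        model_name,
        test_sets=["GSM8K", "MMLU"],
        # Only GSM8K receives extra evaluator kwargs; MMLU falls back to {}
        # because it has no entry in the mapping.
        extra_evaluator_kwargs={"GSM8K": self.extra_evaluator_kwargs})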
@@ -1205,10 +1217,7 @@ class TestDeepSeekV32Exp(LlmapiAccuracyTestHarness):
                                       gen_server_config,
                                       self.MODEL_PATH,
                                       max_workers=128) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
 
 @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
@@ -1247,8 +1256,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
 
     @skip_pre_hopper
     @pytest.mark.skip_less_device(2)
@@ -1284,10 +1292,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
     def test_chunked_prefill(self):
         # bs=1 will stabilize the result, but the test will be much slower
@@ -1325,10 +1330,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       self.MODEL_PATH) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
 
 @skip_pre_blackwell
@@ -1364,7 +1366,6 @@ class TestKimiK2(LlmapiAccuracyTestHarness):
     @pytest.mark.skip_less_device(8)
     @pytest.mark.skip_less_device_memory(200000)
     def test_nvfp4(self):
-        model_name = "moonshotai/Kimi-K2-Thinking"
         model_path = f"{llm_models_root()}/Kimi-K2-Thinking-NVFP4"
         ctx_server_config = {
             "max_batch_size": 16,
@@ -1408,5 +1409,4 @@ class TestKimiK2(LlmapiAccuracyTestHarness):
         with launch_disaggregated_llm(disaggregated_server_config,
                                       ctx_server_config, gen_server_config,
                                       model_path) as llm:
-            task = GSM8K(model_name)
-            task.evaluate(llm)
+            run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])