[https://nvbugs/5558516][test] add disaggregated stress test (#9354)

Signed-off-by: Xin He (SW-GPU) <200704525+xinhe-nv@users.noreply.github.com>
This commit is contained in:
xinhe-nv 2025-12-31 16:47:36 +08:00 committed by GitHub
parent 910a633066
commit 827d12caaf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 473 additions and 66 deletions

View File

@ -32,7 +32,8 @@ from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.sampling_params import SamplingParams
from .trt_test_alternative import check_call, check_output, exists, is_windows
from .trt_test_alternative import (check_call, check_output, exists, is_windows,
print_info, print_warning)
def venv_check_call(venv, cmd, env=None, **kwargs):
@ -1229,3 +1230,32 @@ def revise_disagg_config_file_with_free_ports(disagg_config_file: str) -> str:
yaml.dump(new_config, f)
return new_config_file
def parse_gsm8k_output(output_text: str) -> float:
"""
Parse accuracy value from lm_eval output for GSM8K flexible-extract exact_match
Args:
output_text: The output text from gsm8k command
Returns:
float: The accuracy value (0.7582 in the example)
"""
# Look for the specific pattern: |gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.7559|± |0.0118|
patterns = [
r'flexible-extract\|\s+\d+\|exact_match\|\\s+\|(\d+\.\d+)',
]
for pattern in patterns:
match = re.search(pattern, output_text)
if match:
accuracy_value = float(match.group(1))
print_info(f"Extracted GSM8K accuracy value: {accuracy_value}")
return accuracy_value
print_warning("Could not find GSM8K accuracy value in gsm8k output")
print_warning(f"Output text: {output_text}")
return 0.0

View File

@ -0,0 +1,58 @@
model: gpt_oss/gpt-oss-120b
hostname: localhost
port: 8100
backend: pytorch
context_servers:
num_instances: 1
tensor_parallel_size: 2
pipeline_parallel_size: 1
moe_expert_parallel_size: 2
enable_attention_dp: false
max_num_tokens: 16640
max_seq_len: 8232
max_batch_size: 128
trust_remote_code: true
enable_chunked_prefill: true
kv_cache_config:
enable_block_reuse: false
free_gpu_memory_fraction: 0.80
dtype: fp8
disable_overlap_scheduler: true
moe_config:
backend: TRTLLM
cuda_graph_config: null
print_iter_log: true
cache_transceiver_config:
backend: DEFAULT
max_tokens_in_buffer: 16384
urls:
- "localhost:8101"
generation_servers:
num_instances: 1
tensor_parallel_size: 2
pipeline_parallel_size: 1
moe_expert_parallel_size: 2
enable_attention_dp: false
max_num_tokens: 10240
max_seq_len: 10240
max_batch_size: 128
trust_remote_code: true
enable_chunked_prefill: true
kv_cache_config:
enable_block_reuse: false
free_gpu_memory_fraction: 0.80
dtype: fp8
disable_overlap_scheduler: true
moe_config:
backend: TRTLLM
cuda_graph_config:
enable_padding: true
batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 768, 1024]
print_iter_log: true
cache_transceiver_config:
backend: DEFAULT
max_tokens_in_buffer: 16384
urls:
- "localhost:8102"

View File

@ -0,0 +1,58 @@
model: DeepSeek-R1/DeepSeek-R1-0528-FP4-v2
hostname: localhost
port: 8100
backend: pytorch
context_servers:
num_instances: 1
tensor_parallel_size: 4
pipeline_parallel_size: 1
moe_expert_parallel_size: 4
enable_attention_dp: false
max_num_tokens: 16640
max_seq_len: 8232
max_batch_size: 128
trust_remote_code: true
enable_chunked_prefill: true
kv_cache_config:
enable_block_reuse: false
free_gpu_memory_fraction: 0.80
dtype: fp8
disable_overlap_scheduler: true
moe_config:
backend: TRTLLM
cuda_graph_config: null
print_iter_log: true
cache_transceiver_config:
backend: DEFAULT
max_tokens_in_buffer: 16384
urls:
- "localhost:8101"
generation_servers:
num_instances: 1
tensor_parallel_size: 4
pipeline_parallel_size: 1
moe_expert_parallel_size: 4
enable_attention_dp: false
max_num_tokens: 10240
max_seq_len: 10240
max_batch_size: 128
trust_remote_code: true
enable_chunked_prefill: true
kv_cache_config:
enable_block_reuse: false
free_gpu_memory_fraction: 0.80
dtype: fp8
disable_overlap_scheduler: true
moe_config:
backend: TRTLLM
cuda_graph_config:
enable_padding: true
batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 768, 1024]
print_iter_log: true
cache_transceiver_config:
backend: DEFAULT
max_tokens_in_buffer: 16384
urls:
- "localhost:8102"

View File

@ -18,6 +18,8 @@ import os
import re
import subprocess
import tempfile
import time
from dataclasses import dataclass
from typing import Callable
import pytest
@ -28,11 +30,13 @@ except ImportError:
import tensorrt_llm.ray_stub as ray
import yaml
from defs.common import (revise_disagg_config_file_with_free_ports,
from defs.common import (parse_gsm8k_output,
revise_disagg_config_file_with_free_ports,
wait_for_server)
from defs.conftest import (get_sm_version, llm_models_root, skip_arm,
skip_no_hopper)
from defs.trt_test_alternative import check_call, check_output, popen
skip_no_hopper, skip_pre_blackwell)
from defs.trt_test_alternative import (check_call, check_output, popen,
print_info)
from test_common.perf_metrics_utils import (get_timing_metrics,
validate_timing_metrics)
@ -40,6 +44,18 @@ from tensorrt_llm._utils import get_free_port, mpi_disabled
from tensorrt_llm.logger import logger
@dataclass
class TestConfig:
"""Configuration for disaggregated test."""
model_path: str
test_desc: str
request_count: int
accuracy_threshold: float
def __str__(self):
return self.test_desc
def cleanup_output_files():
"""Clean up output files from previous runs."""
for file in ['output.json', 'output_streaming.json']:
@ -177,6 +193,13 @@ def get_test_config(test_desc, example_dir, test_root):
(4,
f"{test_configs_root}/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml"
),
"deepseek_r1_v2_fp4_stress":
(8,
f"{test_configs_root}/disagg_config_ctxtp4_gentp4_deepseek_r1_v2_fp4_tllm.yaml"
),
"gpt_oss_120b_stress":
(4,
f"{test_configs_root}/disagg_config_ctxtp2_gentp2_gptoss_tllm.yaml"),
}
if test_desc not in config_map:
@ -1611,15 +1634,42 @@ def get_config_for_benchmark(model_root, backend):
return serve_config
def run_disaggregated_genai_perf(config_file,
model_path,
num_ranks,
server_start_timeout=1200,
input_tokens=128000,
output_tokens=100,
env=None,
cwd=None):
"""Run disaggregated test with genai-perf for performance/stress testing."""
def run_disaggregated_aiperf(config_file,
model_path,
num_ranks,
server_start_timeout=1200,
input_tokens=128,
output_tokens=100,
concurrency=1,
endpoint_type='chat',
request_count=None,
warmup_request_count=10,
streaming=True,
random_seed=100,
accuracy_test=False,
threshold=0.8,
env=None,
cwd=None):
"""Run disaggregated test with genai-perf for performance/stress testing.
Args:
config_file: Path to disaggregated server config YAML
model_path: Path to model for tokenizer
num_ranks: Number of MPI ranks for workers
server_start_timeout: Timeout in seconds for server startup
input_tokens: Mean synthetic input tokens
output_tokens: Mean output tokens to generate
concurrency: Number of concurrent requests
endpoint_type: 'chat' or 'completions'
request_count: Total requests (if None, uses concurrency*1024 or num_dataset_entries)
warmup_request_count: Number of warmup requests
streaming: Whether to use streaming mode
random_seed: Random seed for reproducibility
accuracy_test: Whether to run accuracy test
threshold: Threshold for accuracy test
env: Environment variables dict
cwd: Working directory
"""
cleanup_output_files()
run_env = env.copy()
run_env["UCX_TLS"] = "^ib"
@ -1660,28 +1710,67 @@ def run_disaggregated_genai_perf(config_file,
f"Disaggregated server did not become ready within {server_start_timeout} seconds"
)
# Run genai-perf
genai_perf_cmd = [
'genai-perf', 'profile', '--model', model_path, '--tokenizer',
model_path, '--endpoint-type', 'chat', '--endpoint',
'/v1/chat/completions', '--streaming', '--url',
f'{server_host}:{server_port}', '--synthetic-input-tokens-mean',
# Build base command (using aiperf instead of genai-perf)
aiperf_cmd = [
'aiperf', 'profile', '--model', model_path, '--tokenizer',
model_path, '--endpoint-type', endpoint_type
]
# Add endpoint path based on type
if endpoint_type == 'chat':
aiperf_cmd.extend(['--endpoint', '/v1/chat/completions'])
# Add streaming flag if enabled
if streaming:
aiperf_cmd.append('--streaming')
# Add common parameters
aiperf_cmd.extend([
'--url', f'{server_host}:{server_port}',
'--synthetic-input-tokens-mean',
str(input_tokens), '--synthetic-input-tokens-stddev', '0',
'--output-tokens-mean',
str(output_tokens), '--output-tokens-stddev', '0',
'--extra-inputs', f'max_tokens:{output_tokens}',
'--extra-inputs', f'min_tokens:{output_tokens}',
'--extra-inputs', 'ignore_eos:true', '--concurrency', '1',
'--warmup-request-count', '8', '--num-dataset-entries', '64',
'--random-seed', '100', '--artifact-dir', artifact_dir, '--',
'-v', '-H', 'Authorization: Bearer NOT USED', '-H',
'Accept: text/event-stream', '-p', '200000'
]
'--extra-inputs', 'ignore_eos:true', '--concurrency',
str(concurrency), '--warmup-request-count',
str(warmup_request_count)
])
check_call(genai_perf_cmd,
# Use request-count or num-dataset-entries
if request_count is not None:
aiperf_cmd.extend(['--request-count', str(request_count)])
else:
# Default: use num-dataset-entries for compatibility
aiperf_cmd.extend(['--num-dataset-entries', '64'])
aiperf_cmd.extend([
'--random-seed',
str(random_seed), '--artifact-dir', artifact_dir
])
# Run aiperf
check_call(aiperf_cmd,
env=env,
poll_procs=[workers_proc, server_proc])
if accuracy_test:
accuracy_test_result, accuracy_value = run_accuracy_test(
model_path=model_path,
server_url=f"http://{server_host}:{server_port}",
concurrency=concurrency,
max_retries=3,
timeout=1200,
max_gen_toks=256,
max_length=4096)
# only raise error if accuracy test passed and accuracy value is less than threshold
if accuracy_test_result and (accuracy_value < threshold):
raise AssertionError(
f"Accuracy test failed: accuracy value {accuracy_value} is less than test threshold {threshold}"
)
except Exception:
# Print outputs on error
logger.error("-------- Workers output (last 30 lines) --------")
@ -1711,6 +1800,105 @@ def run_disaggregated_genai_perf(config_file,
workers_proc.wait()
def run_accuracy_test(model_path: str, server_url: str, concurrency: int,
max_retries: int, timeout: int, max_gen_toks: int,
max_length: int) -> tuple[bool, float]:
"""
Run accuracy test using lm_eval with GSM8K dataset
Args:
model_path: Path of the model being tested
server_config: Server configuration containing URL and port
concurrency: Concurrency for accuracy tests
max_retries: Max retries for accuracy tests
timeout: Timeout for accuracy tests
max_gen_toks: Max generation tokens for accuracy tests
max_length: Max length for accuracy tests
Returns:
tuple: (Boolean indicating whether the accuracy test completed successfully, accuracy value)
"""
logger.info(f"=== Running ACCURACY TEST (GSM8K) ===")
tmp_dir = tempfile.TemporaryDirectory()
tmp_gsm8k_local_config = os.path.join(tmp_dir.name, "gsm8k_local.yaml")
gsm8k_local_config_path = os.path.join(
os.path.dirname(__file__), '../../lm_eval_configs/gsm8k_local.yaml')
with open(gsm8k_local_config_path, 'r', encoding='utf-8') as f:
config_content = f.read()
# Replace LLM_MODELS_ROOT with actual path
config_content = config_content.replace('LLM_MODELS_ROOT',
llm_models_root())
# Write modified config to temp file
with open(tmp_gsm8k_local_config, 'w', encoding='utf-8') as f:
f.write(config_content)
# Create lm_eval command
lm_eval_cmd = [
"lm_eval",
"--model",
"local-completions",
"--tasks",
"gsm8k_local",
"--include_path",
tmp_dir.name,
"--model_args",
f"model={model_path},base_url={server_url}/v1/completions,"
f"num_concurrent={concurrency},"
f"max_retries={max_retries},"
f"tokenized_requests=False,"
f"timeout={timeout},"
f"max_gen_toks={max_gen_toks},"
f"max_length={max_length}",
]
test_start_time = time.time()
accuracy_value = 0.0
try:
# Run lm_eval process with timeout monitoring
print_info(f"Running lm_eval command: {' '.join(lm_eval_cmd)}")
# Use subprocess.run to capture output directly
result = subprocess.run(lm_eval_cmd,
capture_output=True,
text=True,
timeout=timeout)
print_info(f"Accuracy test result is: {result}")
# Check if process completed successfully
if result.returncode == 0:
test_end_time = time.time()
duration = int(test_end_time - test_start_time)
logger.info(
f"Accuracy test completed successfully in {duration} seconds")
# Parse accuracy value from output
output_text = result.stdout
accuracy_value = parse_gsm8k_output(output_text)
if accuracy_value is not None:
return True, accuracy_value
else:
return False, accuracy_value
else:
logger.warning(
f"lm_eval exited with non-zero code: {result.returncode}")
logger.warning(f"stderr: {result.stderr}")
return False, accuracy_value
except subprocess.TimeoutExpired:
logger.warning(f"Accuracy test timed out after {timeout} seconds")
return False, accuracy_value
except Exception as e:
logger.warning(f"Error during accuracy test: {str(e)}")
return False, accuracy_value
@pytest.mark.parametrize("benchmark_model_root", [
'DeepSeek-V3-Lite-fp8', 'DeepSeek-V3-Lite-bf16', 'llama-v3-8b-hf',
'llama-3.1-8b-instruct-hf-fp8'
@ -1824,14 +2012,14 @@ def test_llama4_long_context_kv_cache_overflow(disaggregated_test_root,
disaggregated_example_root,
os.path.dirname(__file__))
run_disaggregated_genai_perf(config_file=config_file,
model_path=llama4_model_root,
num_ranks=num_ranks,
server_start_timeout=1200,
input_tokens=128000,
output_tokens=100,
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())
run_disaggregated_aiperf(config_file=config_file,
model_path=llama4_model_root,
num_ranks=num_ranks,
server_start_timeout=1200,
input_tokens=128000,
output_tokens=100,
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())
@pytest.mark.skip_less_device(4)
@ -1854,3 +2042,59 @@ def test_disaggregated_deepseek_v3_lite_bf16_tllm_gen_helix(
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory(),
prompt_file="long_prompts.json")
@pytest.mark.timeout(12600)
@pytest.mark.parametrize("test_config", [
pytest.param(TestConfig(model_path='DeepSeek-R1/DeepSeek-R1-0528-FP4-v2',
test_desc='deepseek_r1_v2_fp4_stress',
request_count=35000,
accuracy_threshold=0.92),
marks=(pytest.mark.skip_less_device(8), skip_pre_blackwell)),
pytest.param(TestConfig(model_path='gpt_oss/gpt-oss-120b',
test_desc='gpt_oss_120b_stress',
request_count=60000,
accuracy_threshold=0.42),
marks=(pytest.mark.skip_less_device(4), skip_pre_blackwell)),
],
ids=lambda x: x.test_desc)
@pytest.mark.parametrize("concurrency", [512], ids=lambda x: f"conc{x}")
@pytest.mark.parametrize("output_tokens", [1024],
ids=lambda x: f"output{x//1000}k")
@pytest.mark.parametrize("input_tokens", [8192],
ids=lambda x: f"input{x//1000}k")
def test_disaggregated_stress_test(disaggregated_test_root,
disaggregated_example_root, llm_venv,
test_config, input_tokens, output_tokens,
concurrency):
# Unpack configuration from dataclass
model_path = test_config.model_path
test_desc = test_config.test_desc
model_dir = f"{llm_models_root()}/{model_path}"
src_dst_dict = {
model_dir: f"{llm_venv.get_working_directory()}/{model_path}",
}
for src, dst in src_dst_dict.items():
if not os.path.islink(dst):
os.makedirs(os.path.dirname(dst), exist_ok=True)
os.symlink(src, dst, target_is_directory=True)
num_ranks, config_file = get_test_config(test_desc,
disaggregated_example_root,
os.path.dirname(__file__))
run_disaggregated_aiperf(config_file=config_file,
model_path=model_dir,
num_ranks=num_ranks,
server_start_timeout=7200,
input_tokens=input_tokens,
output_tokens=output_tokens,
concurrency=concurrency,
endpoint_type='completions',
request_count=test_config.request_count,
warmup_request_count=10,
streaming=False,
accuracy_test=True,
threshold=test_config.accuracy_threshold,
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())

View File

@ -44,6 +44,7 @@ import pandas as pd
import pytest
import requests
import yaml
from defs.common import parse_gsm8k_output
from defs.conftest import get_device_count, get_device_memory, llm_models_root
from defs.trt_test_alternative import (Popen, cleanup_process_tree, print_info,
print_warning)
@ -1067,35 +1068,6 @@ def format_time(seconds: int) -> str:
return f"{seconds}s"
def parse_accuracy_from_lm_eval_output(output_text: str) -> float:
"""
Parse accuracy value from lm_eval output for GSM8K flexible-extract exact_match
Args:
output_text: The output text from lm_eval command
Returns:
float: The accuracy value (0.7582 in the example)
"""
import re
# Look for the specific pattern: |gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.7559|± |0.0118|
patterns = [
r'flexible-extract\|\s+\d+\|exact_match\|\\s+\|(\d+\.\d+)',
]
for pattern in patterns:
match = re.search(pattern, output_text)
if match:
accuracy_value = float(match.group(1))
print_info(f"Extracted accuracy value: {accuracy_value}")
return accuracy_value
print_warning("Could not find accuracy value in lm_eval output")
print_warning(f"Output text: {output_text}")
return None
def run_accuracy_test(model_path: str,
server_config: ServerConfig,
stress_config: StressTestConfig,
@ -1155,7 +1127,7 @@ def run_accuracy_test(model_path: str,
# Parse accuracy value from output
output_text = result.stdout
accuracy_value = parse_accuracy_from_lm_eval_output(output_text)
accuracy_value = parse_gsm8k_output(output_text)
return True, accuracy_value
else:
print_warning(

View File

@ -3175,7 +3175,6 @@ def test_multi_nodes_eval(model_path, tp_size, pp_size, ep_size, eval_task,
f"--pp_size={pp_size}",
f"--kv_cache_free_gpu_memory_fraction={_MEM_FRACTION_80}",
"--max_batch_size=32",
"--enable_attention_dp",
"--backend=pytorch",
]

View File

@ -0,0 +1,44 @@
task: gsm8k_local
dataset_path: parquet
dataset_name: null
dataset_kwargs:
data_files:
test: LLM_MODELS_ROOT/datasets/openai/gsm8k/main/test-00000-of-00001.parquet
output_type: generate_until
test_split: test
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: false
regexes_to_ignore:
- ","
- "\\$"
- "(?s).*#### "
- "\\.$"
generation_kwargs:
until:
- "Question:"
- "</s>"
- "<|im_end|>"
do_sample: false
temperature: 0.0
repeats: 1
num_fewshot: 5
filter_list:
- name: "strict-match"
filter:
- function: "regex"
regex_pattern: "#### (\\-?[0-9\\.\\,]+)"
- function: "take_first"
- name: "flexible-extract"
filter:
- function: "regex"
group_select: -1
regex_pattern: "(-?[$0-9.,]{2,})|(-?[0-9]+)"
- function: "take_first"

View File

@ -1,5 +1,7 @@
stress_test/stress_test.py::test_run_stress_test[DeepSeek-V3_tp8-stress_time_3600s_timeout_5400s-GUARANTEED_NO_EVICT-pytorch-stress-test-with-accuracy]
stress_test/stress_test.py::test_run_stress_test[DeepSeek-V3_tp8-stress_time_3600s_timeout_5400s-MAX_UTILIZATION-pytorch-stress-test-with-accuracy]
stress_test/stress_test.py::test_run_stress_test[DeepSeek-R1_tp8-stress_time_3600s_timeout_5400s-MAX_UTILIZATION-pytorch-stress-test-with-accuracy]
disaggregated/test_disaggregated.py::test_disaggregated_stress_test[input8k-output1k-conc512-deepseek_r1_v2_fp4_stress]
disaggregated/test_disaggregated.py::test_disaggregated_stress_test[input8k-output1k-conc512-gpt_oss_120b_stress]
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1LongBenchV2::test_fp8_8gpus
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1LongBenchV2::test_nvfp4_4gpus

View File

@ -383,7 +383,7 @@ accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8ep
accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8-cuda_graph=True] SKIP (https://nvbugs/5707145)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-auto] SKIP (https://nvbugs/5596343)
unittest/_torch/speculative/test_spec_gate.py::test_spec_gate_e2e SKIP (https://nvbugs/5710045)
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5569696)
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5769815)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp_trtllm] SKIP (https://nvbugs/5715568)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp] SKIP (https://nvbugs/5715568)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5721661)