[TRTLLM-7048][feat] add benchmark TRT flow test for MIG (#6884)

Signed-off-by: Xin He (SW-GPU) <200704525+xinhe-nv@users.noreply.github.com>
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
xinhe-nv 2025-08-15 14:01:05 +08:00 committed by GitHub
parent 54ffc6a250
commit c03ea1ba2d


@@ -449,7 +449,9 @@ class BenchRunner:
skip_engine_build: bool = False,
quant: Optional[str] = None,
extra_llm_api_options: Optional[str] = None,
use_mpirun: bool = False):
use_mpirun: bool = False,
concurrency: Optional[int] = None,
num_requests: int = 10):
llm_models = llm_models_root()
assert llm_models is not None
@@ -474,12 +476,14 @@ class BenchRunner:
else:
self.mpirun_cmd = ""
self.engine_path = None
self.concurrency = concurrency
self.num_requests = num_requests
def __call__(self):
self.prepare_dataset()
if not (self.skip_engine_build or self.use_pytorch_backend):
self.build_engine()
self.run_bench()
return self.run_bench()
def prepare_dataset(self):
dataset_tool = Path(self.llm_root, "benchmarks", "cpp",
@@ -502,7 +506,7 @@ class BenchRunner:
"--output-stdev",
"0",
"--num-requests",
"10",
str(self.num_requests),
]
print(f"Running command: {' '.join(command)}")
dataset_output = self.llm_venv.run_cmd(
@@ -556,7 +560,47 @@ class BenchRunner:
if self.extra_llm_api_options:
benchmark_cmd += f" --extra_llm_api_options {self.extra_llm_api_options}"
check_call(benchmark_cmd, shell=True, env=self.llm_venv._new_env)
if self.concurrency:
benchmark_cmd += f" --concurrency {self.concurrency}"
if self.num_requests:
benchmark_cmd += f" --num_requests {self.num_requests}"
benchmark_output = check_output(benchmark_cmd,
shell=True,
env=self.llm_venv._new_env)
return self.parse_benchmark_output(benchmark_output)
def parse_benchmark_output(self, output):
"""Parse the benchmark output to extract key metrics."""
result = {
'concurrency': self.concurrency,
'num_requests': self.num_requests,
'throughput': 0,
'latency': 0
}
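# Scan the summary for lines such as "Total Token Throughput (tokens/sec): ..." and
# "Total Latency (ms): ..." (matched case-insensitively) and keep the value after the colon.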
lines = output.split('\n')
for line in lines:
line = line.strip()
if 'total token throughput' in line.lower(
) and 'tokens/sec' in line.lower():
try:
throughput = line.split(":")[1].strip()
result['throughput'] = throughput
except (IndexError, ValueError) as e:
print(
f"Failed to parse throughput from line: {line}. Error: {e}"
)
elif 'total latency' in line.lower() and 'ms' in line.lower():
try:
latency = line.split(":")[1].strip()
result['latency'] = latency
except (IndexError, ValueError) as e:
print(
f"Failed to parse latency from line: {line}. Error: {e}"
)
return result
@pytest.mark.parametrize("model_name", ["meta-llama/Meta-Llama-3-8B-Instruct"],
@@ -579,6 +623,67 @@ def test_trtllm_bench_llmapi_launch(llm_root, llm_venv, model_name,
runner()
@skip_pre_hopper
@pytest.mark.skip_less_device_memory(80000)
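# Requires a Hopper-or-newer GPU with at least 80,000 MB of device memory; skipped otherwise.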
@pytest.mark.parametrize("model_name", ["meta/Meta-Llama-3.1-8B"],
ids=["llama3_1-8b"])
@pytest.mark.parametrize("model_subdir", ["llama-3.1-model/Meta-Llama-3.1-8B"],
ids=["llama_v3_1"])
@pytest.mark.parametrize("use_pytorch_backend", [False], ids=["trt_backend"])
def test_trtllm_bench_mig_launch(llm_root, llm_venv, model_name, model_subdir,
use_pytorch_backend):
"run bench mark in MIG mode, check if the throughput is increasing by concurrency"
skip_engine_build = False
results = {}
concurrency_list = [1, 32, 64, 128]
for concurrency in concurrency_list:
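# Keep the workload proportional to concurrency: 10 requests per concurrent stream.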
num_requests = concurrency * 10
runner = BenchRunner(llm_root=llm_root,
llm_venv=llm_venv,
model_name=model_name,
model_subdir=model_subdir,
streaming=False,
use_pytorch_backend=use_pytorch_backend,
use_mpirun=False,
tp_size=1,
concurrency=concurrency,
num_requests=num_requests,
skip_engine_build=skip_engine_build)
output = runner()
results[concurrency] = output
print(f"\n=== Benchmark Results Comparison ===")
print(f"Model: {model_name}")
print(f"Backend: {'PyTorch' if use_pytorch_backend else 'TensorRT'}")
print(
f"{'Concurrency':<15} {'Throughput':<15} {'Latency':<15} {'Num Requests':<15}"
)
print("-" * 60)
for idx, val in enumerate(concurrency_list):
metrics = results.get(val)
if not isinstance(metrics, dict):
pytest.fail(
f"Unexpected benchmark result type for concurrency {val}: {type(metrics)}"
)
try:
throughput = float(metrics.get('throughput', 0))
latency = float(metrics.get('latency', 0))
num_requests = int(metrics.get('num_requests', 0))
except (ValueError, TypeError) as e:
pytest.fail(
f"Failed to parse benchmark results for concurrency {val}: {e}")
assert throughput > 0, f"Throughput is 0 for concurrency {val}"
assert latency > 0, f"Latency is 0 for concurrency {val}"
print(f"{val:<15} {throughput:<15} {latency:<15} {num_requests:<15}")
if idx > 0:
prev_throughput = float(results[concurrency_list[idx - 1]].get(
'throughput', 0))
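# Each step up in concurrency (1 -> 32 -> 64 -> 128) must deliver at least a 1.3x throughput gain.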
assert throughput > prev_throughput * 1.3, f"Throughput is not increasing for concurrency {concurrency_list[idx]}"
@pytest.mark.parametrize(
"model_name, llama_model_root",
[pytest.param("TinyLlama-1.1B-Chat-v1.0", "TinyLlama-1.1B-Chat-v1.0")],