Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[TRTLLM-7048][feat] add benchmark TRT flow test for MIG (#6884)
Signed-off-by: Xin He (SW-GPU) <200704525+xinhe-nv@users.noreply.github.com>
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
parent 54ffc6a250
commit c03ea1ba2d
@@ -449,7 +449,9 @@ class BenchRunner:
                  skip_engine_build: bool = False,
                  quant: Optional[str] = None,
                  extra_llm_api_options: Optional[str] = None,
-                 use_mpirun: bool = False):
+                 use_mpirun: bool = False,
+                 concurrency: Optional[int] = None,
+                 num_requests: int = 10):

         llm_models = llm_models_root()
         assert llm_models is not None
@@ -474,12 +476,14 @@ class BenchRunner:
         else:
             self.mpirun_cmd = ""
         self.engine_path = None
+        self.concurrency = concurrency
+        self.num_requests = num_requests

     def __call__(self):
         self.prepare_dataset()
         if not (self.skip_engine_build or self.use_pytorch_backend):
             self.build_engine()
-        self.run_bench()
+        return self.run_bench()

     def prepare_dataset(self):
         dataset_tool = Path(self.llm_root, "benchmarks", "cpp",
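Taken together, the first two hunks thread two new knobs through BenchRunner: concurrency and num_requests become constructor parameters, and __call__ now returns the value of run_bench() instead of discarding it, so callers receive the parsed metrics. A minimal usage sketch under those assumptions (the keyword values below are illustrative, not taken from the commit):

# Hypothetical usage sketch; fixture objects and values are assumptions.
runner = BenchRunner(llm_root=llm_root,
                     llm_venv=llm_venv,
                     model_name="meta/Meta-Llama-3.1-8B",
                     model_subdir="llama-3.1-model/Meta-Llama-3.1-8B",
                     streaming=False,
                     use_pytorch_backend=False,
                     use_mpirun=False,
                     tp_size=1,
                     concurrency=32,    # new parameter
                     num_requests=320)  # new parameter, defaults to 10
metrics = runner()  # now returns the dict built by parse_benchmark_output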
@@ -502,7 +506,7 @@ class BenchRunner:
             "--output-stdev",
             "0",
             "--num-requests",
-            "10",
+            str(self.num_requests),
         ]
         print(f"Running command: {' '.join(command)}")
         dataset_output = self.llm_venv.run_cmd(
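The dataset hunk replaces the hardcoded "10" with str(self.num_requests), so the synthetic dataset scales with the requested load; the str() call is needed because the command is assembled as a list of argv strings for the venv runner. A small sketch of the pattern, with the surrounding flags taken from the hunk and the request count invented:

# Sketch only; 320 mirrors concurrency 32 * 10 in the MIG test below.
num_requests = 320
command = [
    "--output-stdev",
    "0",
    "--num-requests",
    str(num_requests),  # argv entries must be strings, hence str()
]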
@@ -556,7 +560,47 @@ class BenchRunner:

         if self.extra_llm_api_options:
             benchmark_cmd += f" --extra_llm_api_options {self.extra_llm_api_options}"
-        check_call(benchmark_cmd, shell=True, env=self.llm_venv._new_env)
+        if self.concurrency:
+            benchmark_cmd += f" --concurrency {self.concurrency}"
+        if self.num_requests:
+            benchmark_cmd += f" --num_requests {self.num_requests}"
+
+        benchmark_output = check_output(benchmark_cmd,
+                                        shell=True,
+                                        env=self.llm_venv._new_env)
+        return self.parse_benchmark_output(benchmark_output)
+
+    def parse_benchmark_output(self, output):
+        """Parse the benchmark output to extract key metrics."""
+        result = {
+            'concurrency': self.concurrency,
+            'num_requests': self.num_requests,
+            'throughput': 0,
+            'latency': 0
+        }
+
+        lines = output.split('\n')
+        for line in lines:
+            line = line.strip()
+            if 'total token throughput' in line.lower(
+            ) and 'tokens/sec' in line.lower():
+                try:
+                    throughput = line.split(":")[1].strip()
+                    result['throughput'] = throughput
+                except (IndexError, ValueError) as e:
+                    print(
+                        f"Failed to parse throughput from line: {line}. Error: {e}"
+                    )
+            elif 'total latency' in line.lower() and 'ms' in line.lower():
+                try:
+                    latency = line.split(":")[1].strip()
+                    result['latency'] = latency
+                except (IndexError, ValueError) as e:
+                    print(
+                        f"Failed to parse latency from line: {line}. Error: {e}"
+                    )
+
+        return result


 @pytest.mark.parametrize("model_name", ["meta-llama/Meta-Llama-3-8B-Instruct"],
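parse_benchmark_output matches on lowercased substrings and keeps whatever follows the first colon, so the parsed throughput and latency stay strings until the test coerces them with float(). A self-contained sketch of the same logic; the sample label text is an assumption about what trtllm-bench prints, not captured output:

# Standalone sketch; the sample lines are assumed, not real trtllm-bench output.
sample_output = ("Total Token Throughput (tokens/sec): 1523.40\n"
                 "Total Latency (ms): 2101.55")

result = {'throughput': 0, 'latency': 0}
for line in sample_output.split('\n'):
    line = line.strip()
    if 'total token throughput' in line.lower() and 'tokens/sec' in line.lower():
        result['throughput'] = line.split(":")[1].strip()
    elif 'total latency' in line.lower() and 'ms' in line.lower():
        result['latency'] = line.split(":")[1].strip()

assert result == {'throughput': '1523.40', 'latency': '2101.55'}

One caveat worth noting: split(":")[1] keeps only the text between the first and second colon, which is fine for "label: value" lines but would truncate a value that itself contains a colon.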
@@ -579,6 +623,67 @@ def test_trtllm_bench_llmapi_launch(llm_root, llm_venv, model_name,
     runner()


+@skip_pre_hopper
+@pytest.mark.skip_less_device_memory(80000)
+@pytest.mark.parametrize("model_name", ["meta/Meta-Llama-3.1-8B"],
+                         ids=["llama3_1-8b"])
+@pytest.mark.parametrize("model_subdir", ["llama-3.1-model/Meta-Llama-3.1-8B"],
+                         ids=["llama_v3_1"])
+@pytest.mark.parametrize("use_pytorch_backend", [False], ids=["trt_backend"])
+def test_trtllm_bench_mig_launch(llm_root, llm_venv, model_name, model_subdir,
+                                 use_pytorch_backend):
+    """Run the benchmark in MIG mode and check that throughput increases with concurrency."""
+    skip_engine_build = False
+    results = {}
+    concurrency_list = [1, 32, 64, 128]
+
+    for concurrency in concurrency_list:
+        num_requests = concurrency * 10
+        runner = BenchRunner(llm_root=llm_root,
+                             llm_venv=llm_venv,
+                             model_name=model_name,
+                             model_subdir=model_subdir,
+                             streaming=False,
+                             use_pytorch_backend=use_pytorch_backend,
+                             use_mpirun=False,
+                             tp_size=1,
+                             concurrency=concurrency,
+                             num_requests=num_requests,
+                             skip_engine_build=skip_engine_build)
+
+        output = runner()
+        results[concurrency] = output
+
+    print("\n=== Benchmark Results Comparison ===")
+    print(f"Model: {model_name}")
+    print(f"Backend: {'PyTorch' if use_pytorch_backend else 'TensorRT'}")
+    print(
+        f"{'Concurrency':<15} {'Throughput':<15} {'Latency':<15} {'Num Requests':<15}"
+    )
+    print("-" * 60)
+
+    for idx, val in enumerate(concurrency_list):
+        metrics = results.get(val)
+        if not isinstance(metrics, dict):
+            pytest.fail(
+                f"Unexpected benchmark result type for concurrency {val}: {type(metrics)}"
+            )
+        try:
+            throughput = float(metrics.get('throughput', 0))
+            latency = float(metrics.get('latency', 0))
+            num_requests = int(metrics.get('num_requests', 0))
+        except (ValueError, TypeError) as e:
+            pytest.fail(
+                f"Failed to parse benchmark results for concurrency {val}: {e}")
+        assert throughput > 0, f"Throughput is 0 for concurrency {val}"
+        assert latency > 0, f"Latency is 0 for concurrency {val}"
+        print(f"{val:<15} {throughput:<15} {latency:<15} {num_requests:<15}")
+        if idx > 0:
+            prev_throughput = float(results[concurrency_list[idx - 1]].get(
+                'throughput', 0))
+            assert throughput > prev_throughput * 1.3, f"Throughput is not increasing for concurrency {val}"
+
+
 @pytest.mark.parametrize(
     "model_name, llama_model_root",
     [pytest.param("TinyLlama-1.1B-Chat-v1.0", "TinyLlama-1.1B-Chat-v1.0")],
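The new test drives the same BenchRunner through a concurrency ladder and gates each rung on more than a 30% throughput gain over the previous one (the * 1.3 factor), which is stricter than plain monotonicity. A reduced sketch of just that gate, with invented throughput figures:

# Sketch of the scaling gate in isolation; throughput numbers are invented.
concurrency_list = [1, 32, 64, 128]
results = {1: 95.0, 32: 1400.0, 64: 2300.0, 128: 3600.0}

for idx, val in enumerate(concurrency_list):
    throughput = results[val]
    if idx > 0:
        prev_throughput = results[concurrency_list[idx - 1]]
        # Each rung must beat the previous one by more than 30%.
        assert throughput > prev_throughput * 1.3, (
            f"Throughput is not increasing for concurrency {val}")

A factor this aggressive assumes the MIG slice is far from saturation at low concurrency; if throughput plateaus between 64 and 128, the gate fails even though the run itself was healthy.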