#!/usr/bin/env python3
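"""Run trtllm-serve benchmark sweeps defined in a YAML configuration file.

For each test case selected via --select/--skip, the script launches a
trtllm-serve instance, waits for it to report ready, runs
tensorrt_llm.serve.scripts.benchmark_serving at each configured concurrency,
and writes server and benchmark logs into the output folder.

Illustrative invocation (script name and paths are placeholders):

    python <this_script>.py \
        --output_folder /path/to/results \
        --config_file /path/to/config.yaml \
        --select "1,3" \
        --skip "3-2"
"""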
import argparse
import os
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

import requests
import yaml

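# Expected shape of the YAML config file (illustrative sketch: the field names
# are exactly the keys this script reads, the values are made-up examples):
#
# test_cases:
#   - id: 1
#     model: "R1-FP4"
#     gpus: 8
#     tp: 8
#     ep: 8
#     attn_backend: "TRTLLM"
#     moe_backend: "CUTLASS"
#     enable_attention_dp: true
#     free_gpu_mem_fraction: 0.9
#     max_batch_size: 128
#     max_num_tokens: 8192
#     moe_max_num_tokens: null
#     isl: 1024
#     osl: 1024
#     concurrency_iterations:   # list of [concurrency, iterations] pairs
#       - [32, 10]
#       - [128, 5]
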
class BenchmarkRunner:

    def __init__(self,
                 output_folder: str,
                 config_file: str,
                 skip_pattern: Optional[str] = None,
                 select_pattern: Optional[str] = None):
        self.output_folder = Path(output_folder)
        self.config_file = Path(config_file)

        # Treat empty or "default" values as None (default behavior)
        self.skip_pattern = (None if not skip_pattern
                             or skip_pattern.lower() == "default"
                             else skip_pattern)
        self.select_pattern = (None if not select_pattern
                               or select_pattern.lower() == "default"
                               else select_pattern)

        self.skip_test_cases: Set[int] = set()
        self.skip_concurrencies: Dict[int, Set[int]] = {}
        self.select_test_cases: Set[int] = set()
        self.select_concurrencies: Dict[int, Set[int]] = {}

        if self.skip_pattern:
            self.parse_skip_pattern(self.skip_pattern)

        if self.select_pattern:
            self.parse_select_pattern(self.select_pattern)

        # Execution plan: {test_case_id: [concurrency_indices]}
        self.execution_plan: Dict[int, List[int]] = {}

        # Model path mapping
        self.model_paths = {
            "70B-FP4":
            "/home/scratch.trt_llm_data/llm-models/llama-3.3-models/Llama-3.3-70B-Instruct-FP4",
            "70B-FP8":
            "/home/scratch.trt_llm_data/llm-models/llama-3.3-models/Llama-3.3-70B-Instruct-FP8",
            "Scout-FP4":
            "/home/scratch.trt_llm_data/llm-models/llama4-models/Llama-4-Scout-17B-16E-Instruct-FP4",
            "Scout-FP8":
            "/home/scratch.trt_llm_data/llm-models/llama4-models/Llama-4-Scout-17B-16E-Instruct-FP8",
            "R1-FP8":
            "/home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1/",
            "R1-FP4":
            "/home/scratch.trt_llm_data/llm-models/DeepSeek-R1/DeepSeek-R1-0528-FP4"
        }

        # Set environment variables
        os.environ['TQDM_MININTERVAL'] = '1000'
        os.environ['PRINT_ITER_LOG'] = 'false'

        # Capture system information
        self.node_name = self.get_node_name()
        self.gpu_info = self.get_gpu_info()

        # Change to output directory
        os.chdir(self.output_folder)

    def get_node_name(self) -> str:
        """Get the current node name"""
        try:
            result = subprocess.run("hostname",
                                    shell=True,
                                    capture_output=True,
                                    text=True,
                                    check=True)
            return result.stdout.strip()
        except (subprocess.CalledProcessError, FileNotFoundError):
            return "unknown"

    def get_gpu_info(self) -> str:
        """Get GPU information from nvidia-smi"""
        try:
            result = subprocess.run("nvidia-smi",
                                    shell=True,
                                    capture_output=True,
                                    text=True,
                                    check=True)
            return result.stdout
        except subprocess.CalledProcessError as e:
            return f"nvidia-smi failed with error code {e.returncode}\nError output: {e.stderr}"
        except FileNotFoundError:
            return "nvidia-smi not found"

    def parse_skip_pattern(self, skip_pattern: str) -> None:
        """Parse skip pattern like '2,4-1' to determine what to skip"""
        if not skip_pattern:
            return

        parts = skip_pattern.split(',')
        for part in parts:
            part = part.strip()
            if not part:  # Skip empty parts
                continue

            if '-' in part:
                # Format: "test_case-concurrency_index" (1-based)
                try:
                    test_case_str, concurrency_str = part.split('-')
                    test_case_id = int(test_case_str)
                    concurrency_index = int(
                        concurrency_str) - 1  # Convert to 0-based

                    if test_case_id not in self.skip_concurrencies:
                        self.skip_concurrencies[test_case_id] = set()
                    self.skip_concurrencies[test_case_id].add(concurrency_index)
                except ValueError:
                    raise ValueError(
                        f"Invalid skip pattern '{part}'. Expected format: 'test_case-concurrency_index' (e.g., '2-1')"
                    )
            else:
                # Format: "test_case" - skip entire test case
                try:
                    test_case_id = int(part)
                    self.skip_test_cases.add(test_case_id)
                except ValueError:
                    raise ValueError(
                        f"Invalid test case ID '{part}' in skip pattern. Must be a valid integer."
                    )

        print(f"Skipping test cases: {sorted(self.skip_test_cases)}")
        print(f"Skipping concurrencies: {self.skip_concurrencies}")

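    # Example: parse_skip_pattern("2,4-1") yields
    #   skip_test_cases    == {2}        (skip all of test case 2)
    #   skip_concurrencies == {4: {0}}   (skip the 1st concurrency of test case 4)
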
    def parse_select_pattern(self, select_pattern: str) -> None:
        """Parse select pattern like '1,3,5' or '1-1,2-3' to determine which test cases/concurrencies to run"""
        if not select_pattern:
            return

        self.select_concurrencies: Dict[int, Set[int]] = {}

        parts = select_pattern.split(',')
        for part in parts:
            part = part.strip()
            if not part:  # Skip empty parts
                continue

            if '-' in part:
                # Format: "test_case-concurrency_index" (1-based)
                try:
                    test_case_str, concurrency_str = part.split('-')
                    test_case_id = int(test_case_str)
                    concurrency_index = int(
                        concurrency_str) - 1  # Convert to 0-based

                    if test_case_id not in self.select_concurrencies:
                        self.select_concurrencies[test_case_id] = set()
                    self.select_concurrencies[test_case_id].add(
                        concurrency_index)
                except ValueError:
                    raise ValueError(
                        f"Invalid select pattern '{part}'. Expected format: 'test_case-concurrency_index' (e.g., '2-1')"
                    )
            else:
                # Format: "test_case" - select entire test case
                try:
                    test_case_id = int(part)
                    self.select_test_cases.add(test_case_id)
                except ValueError:
                    raise ValueError(
                        f"Invalid test case ID '{part}' in select pattern. Must be a valid integer."
                    )

        print(f"Selected test cases: {sorted(self.select_test_cases)}")
        print(f"Selected concurrencies: {self.select_concurrencies}")

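    # Example: parse_select_pattern("1,3-2") yields
    #   select_test_cases    == {1}        (run all of test case 1)
    #   select_concurrencies == {3: {1}}   (run only the 2nd concurrency of test case 3)
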
    def build_execution_plan(self, test_cases: List[Dict[str, Any]]) -> None:
        """Build execution plan by analyzing config file, skip_pattern, and select_pattern"""
        self.execution_plan.clear()

        # Step 1: Initialize execution plan based on select_pattern
        if not self.select_pattern:
            # If select_pattern is empty or default, include all test cases with all concurrencies
            for test_case in test_cases:
                test_case_id = test_case['id']
                all_concurrencies = list(
                    range(len(test_case['concurrency_iterations'])))
                self.execution_plan[test_case_id] = all_concurrencies
        else:
            # If select_pattern is specified, only include selected test cases and concurrencies
            for test_case in test_cases:
                test_case_id = test_case['id']

                # Check if this test case is selected
                if test_case_id in self.select_test_cases:
                    # Test case is selected - include all concurrencies
                    all_concurrencies = list(
                        range(len(test_case['concurrency_iterations'])))
                    self.execution_plan[test_case_id] = all_concurrencies
                elif test_case_id in self.select_concurrencies:
                    # Specific concurrencies are selected for this test case
                    selected_concurrencies = list(
                        self.select_concurrencies[test_case_id])
                    # Validate that selected concurrencies exist in config
                    max_concurrency_index = len(
                        test_case['concurrency_iterations']) - 1
                    valid_concurrencies = [
                        c for c in selected_concurrencies
                        if 0 <= c <= max_concurrency_index
                    ]
                    if valid_concurrencies:
                        self.execution_plan[test_case_id] = valid_concurrencies

        # Step 2: Apply skip_pattern to remove test cases and concurrencies
        # Remove entire test cases that are in skip_test_cases
        for test_case_id in self.skip_test_cases:
            if test_case_id in self.execution_plan:
                del self.execution_plan[test_case_id]

        # Remove specific concurrencies that are in skip_concurrencies
        for test_case_id, skip_concurrency_indices in self.skip_concurrencies.items():
            if test_case_id in self.execution_plan:
                # Remove skipped concurrencies from the list
                remaining_concurrencies = [
                    c for c in self.execution_plan[test_case_id]
                    if c not in skip_concurrency_indices
                ]
                if remaining_concurrencies:
                    self.execution_plan[test_case_id] = remaining_concurrencies
                else:
                    # If no concurrencies remain, remove the entire test case
                    del self.execution_plan[test_case_id]

        # Step 3: Clean up - remove test cases with empty concurrency lists
        # (This should not happen with the above logic, but just to be safe)
        test_cases_to_remove = []
        for test_case_id, concurrencies in self.execution_plan.items():
            if not concurrencies:
                test_cases_to_remove.append(test_case_id)

        for test_case_id in test_cases_to_remove:
            del self.execution_plan[test_case_id]

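    # Example: for a config whose test cases 1-3 each define three
    # concurrencies, build_execution_plan with select_pattern="1,3-2" and
    # skip_pattern="1-1" produces
    #   execution_plan == {1: [1, 2], 3: [1]}
    # (0-based indices into each test case's concurrency_iterations list).
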
    def print_execution_plan(self, test_cases: List[Dict[str, Any]]) -> None:
        """Print which test cases and concurrencies will be executed"""
        print("\n" + "=" * 80)
        print("EXECUTION PLAN")
        print("=" * 80)

        total_test_cases = 0
        total_concurrencies = 0

        for test_case in test_cases:
            test_case_id = test_case['id']
            model_label = test_case['model']

            # Check if this test case is in execution plan
            if test_case_id not in self.execution_plan:
                print(f"Test Case {test_case_id}: {model_label} - SKIPPED")
                continue

            total_test_cases += 1
            print(f"\nTest Case {test_case_id}: {model_label}")
            print(
                f"  Config: GPUs={test_case['gpus']}, TP={test_case['tp']}, EP={test_case['ep']}, attn_backend={test_case['attn_backend']}, moe_backend={test_case['moe_backend']}"
            )

            # Get concurrencies from execution plan
            concurrencies_to_run = []
            for concurrency_index in self.execution_plan[test_case_id]:
                concurrency, iteration = test_case['concurrency_iterations'][
                    concurrency_index]
                concurrencies_to_run.append(
                    (concurrency_index + 1, concurrency,
                     iteration))  # +1 for 1-based display
                total_concurrencies += 1

            print(
                f"  Concurrencies to run ({len(concurrencies_to_run)}/{len(test_case['concurrency_iterations'])}):"
            )
            for concurrency_num, concurrency, iteration in concurrencies_to_run:
                print(
                    f"    {concurrency_num}. Concurrency={concurrency}, Iteration={iteration}"
                )

        print("\n" + "=" * 80)
        print(
            f"SUMMARY: {total_test_cases} test cases, {total_concurrencies} concurrencies will be executed"
        )
        print("=" * 80 + "\n")

    def generate_extra_llm_api_config(self, test_case: Dict[str, Any]) -> str:
        """Generate extra-llm-api-config.yml content"""
        config_lines = [
            "print_iter_log: true",
            f"enable_attention_dp: {str(test_case['enable_attention_dp']).lower()}",
            "disable_overlap_scheduler: false",
            "stream_interval: 10",
            f"attn_backend: {test_case['attn_backend']}",
            "cuda_graph_config:",
            "  enable_padding: true",
            f"  max_batch_size: {test_case['max_batch_size']}",
            "kv_cache_config:",
            "  dtype: fp8",
            f"  free_gpu_memory_fraction: {test_case['free_gpu_mem_fraction']}",
            "  enable_block_reuse: false",
        ]

        # Add moe_config if moe_backend is specified
        if test_case['moe_backend']:
            config_lines.append("moe_config:")
            config_lines.append(f"  backend: {test_case['moe_backend']}")

        if test_case['moe_max_num_tokens']:
            config_lines.append(
                f"  max_num_tokens: {test_case['moe_max_num_tokens']}")

        return "\n".join(config_lines)

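    # For illustration, with attn_backend="TRTLLM", moe_backend="CUTLASS",
    # enable_attention_dp=True, max_batch_size=128, free_gpu_mem_fraction=0.9
    # and moe_max_num_tokens unset, generate_extra_llm_api_config renders:
    #
    #   print_iter_log: true
    #   enable_attention_dp: true
    #   disable_overlap_scheduler: false
    #   stream_interval: 10
    #   attn_backend: TRTLLM
    #   cuda_graph_config:
    #     enable_padding: true
    #     max_batch_size: 128
    #   kv_cache_config:
    #     dtype: fp8
    #     free_gpu_memory_fraction: 0.9
    #     enable_block_reuse: false
    #   moe_config:
    #     backend: CUTLASS
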
    def wait_for_server(self,
                        server_pid: int,
                        server_log_filename: str,
                        max_attempts: int = 360) -> bool:
        """Wait for server to be ready"""
        print("Waiting for trtllm-serve to be ready...")

        for attempt in range(1, max_attempts + 1):
            # Check if server is still running
            try:
                os.kill(server_pid, 0)  # Check if process exists
            except OSError:
                print("Error: Server process has died")
                return False

            # Check server log for runtime errors
            if self.check_for_runtime_error(server_log_filename):
                print(
                    f"RuntimeError detected in server log: {server_log_filename}"
                )
                print("Killing server process due to runtime error")
                try:
                    subprocess.run(f"kill -9 {server_pid}",
                                   shell=True,
                                   check=False)
                    subprocess.run(f"wait {server_pid} 2>/dev/null || true",
                                   shell=True,
                                   check=False)
                except Exception as e:
                    print(f"Warning: Error killing server process: {e}")
                return False

            # Try to connect to server
            try:
                response = requests.get("http://localhost:8000/v1/models",
                                        timeout=5)
                if response.status_code == 200:
                    print(
                        f"Server is ready! HTTP status: {response.status_code}")
                    return True
            except requests.RequestException:
                pass

            print(
                f"Attempt {attempt}/{max_attempts}: Server not ready yet, waiting..."
            )
            time.sleep(10)

        print(
            f"Error: Server did not become ready after {max_attempts} attempts")
        return False

    def check_for_runtime_error(self, log_file_path: str) -> bool:
        """Check if RuntimeError exists in log file"""
        try:
            if os.path.exists(log_file_path):
                with open(log_file_path, 'r') as f:
                    content = f.read()
                    if ("RuntimeError" in content or "runtime error" in content
                            or "illegal memory access" in content
                            or "terminate called" in content):
                        return True
        except Exception as e:
            print(f"Warning: Could not read log file {log_file_path}: {e}")
        return False

    def run_benchmark(self, test_case: Dict[str, Any], concurrency: int,
                      iteration: int, model_path: str,
                      server_log_filename: str) -> bool:
        """Run a single benchmark with monitoring. Returns True if successful, False if should skip test case"""
        num_prompts = concurrency * iteration

        print(
            f'Running benchmark with concurrency: {concurrency}, iteration: {iteration}, num-prompts: {num_prompts}'
        )

        # Build benchmark command
        benchmark_cmd = [
            "python", "-m", "tensorrt_llm.serve.scripts.benchmark_serving",
            "--model", model_path, "--dataset-name", "random", "--random-ids",
            "--num-prompts",
            str(num_prompts), "--random-input-len",
            str(test_case['isl']), "--random-output-len",
            str(test_case['osl']), "--random-range-ratio", "0.0",
            "--ignore-eos", "--percentile-metrics", "ttft,tpot,itl,e2el",
            "--max-concurrency",
            str(concurrency)
        ]

        print('Running benchmark with command:')
        print(' '.join(benchmark_cmd))
        print()

        # Prepare log filename
        benchmark_log_filename = (
            f"serve.{test_case['model']}.tp{test_case['tp']}.ep{test_case['ep']}."
            f"attn{test_case['attn_backend']}.moe{test_case['moe_backend']}."
            f"gpu{test_case['free_gpu_mem_fraction']}.batch{test_case['max_batch_size']}."
            f"isl{test_case['isl']}.osl{test_case['osl']}."
            f"tokens{test_case['max_num_tokens']}.moetokens{test_case['moe_max_num_tokens']}."
            f"concurrency{concurrency}.iter{iteration}.log")

        try:
            with open(benchmark_log_filename, 'w') as f:
                f.write(f"GPU Info: {self.gpu_info}\n")

            # Start benchmark as subprocess
            with open(benchmark_log_filename, 'a') as log_file:
                benchmark_process = subprocess.Popen(benchmark_cmd,
                                                     stdout=log_file,
                                                     stderr=subprocess.STDOUT)

            # Monitor logs every 60 seconds with timeout
            print(
                f"Starting log monitoring for benchmark process (PID: {benchmark_process.pid})"
            )

            start_time = time.time()
            timeout_seconds = 3600  # 1 hour timeout

            while benchmark_process.poll() is None:  # Process is still running
                time.sleep(60)  # Wait 60 seconds

                # Check if benchmark has been running for more than 1 hour
                elapsed_time = time.time() - start_time
                if elapsed_time > timeout_seconds:
                    print(
                        f"Benchmark timeout after {elapsed_time:.0f} seconds (>{timeout_seconds} seconds)"
                    )
                    print("Killing benchmark process due to timeout")
                    try:
                        subprocess.run(f"kill -9 {benchmark_process.pid}",
                                       shell=True,
                                       check=False)
                        benchmark_process.wait(timeout=10)
                    except Exception as e:
                        print(f"Warning: Error killing benchmark process: {e}")
                    return False  # Signal to skip test case

                print(
                    f"Checking logs for RuntimeError... (benchmark PID: {benchmark_process.pid}, elapsed: {elapsed_time:.0f}s)"
                )

                # Check server log for RuntimeError
                if self.check_for_runtime_error(server_log_filename):
                    print(
                        f"RuntimeError found in server log: {server_log_filename}"
                    )
                    print(
                        "Killing benchmark process and skipping this test case")
                    try:
                        subprocess.run(f"kill -9 {benchmark_process.pid}",
                                       shell=True,
                                       check=False)
                        benchmark_process.wait(timeout=10)
                    except Exception as e:
                        print(f"Warning: Error killing benchmark process: {e}")
                    return False  # Signal to skip test case

                # Check benchmark log for RuntimeError
                if self.check_for_runtime_error(benchmark_log_filename):
                    print(
                        f"RuntimeError found in benchmark log: {benchmark_log_filename}"
                    )
                    print(
                        "Killing benchmark process and skipping this test case")
                    try:
                        subprocess.run(f"kill -9 {benchmark_process.pid}",
                                       shell=True,
                                       check=False)
                        benchmark_process.wait(timeout=10)
                    except Exception as e:
                        print(f"Warning: Error killing benchmark process: {e}")
                    return False  # Signal to skip test case

            # Process completed, check final return code
            return_code = benchmark_process.returncode
            if return_code != 0:
                print(
                    f"Benchmark process completed with error code: {return_code}"
                )

                # Read and display error output
                try:
                    with open(benchmark_log_filename, 'r') as f:
                        error_content = f.read()
                        print(
                            f"Benchmark error output:\n{error_content[-1000:]}"
                        )  # Last 1000 chars
                except Exception as e:
                    print(f"Could not read benchmark log: {e}")

                print(
                    "Skipping this concurrency level and continuing with next one..."
                )
                print("-----------------------------------------")
                return True  # Continue with next concurrency, don't skip test case

            # Success case
            print(
                f"Benchmark completed successfully (PID: {benchmark_process.pid})"
            )

            # Add configuration summary to log file
            config_summary = (
                f"Completed benchmark with Configuration: "
                f"model_label={test_case['model']}, GPUs={test_case['gpus']}, "
                f"TP={test_case['tp']}, EP={test_case['ep']}, "
                f"attn_backend={test_case['attn_backend']}, "
                f"moe_backend={test_case['moe_backend']}, "
                f"enable_attention_dp={test_case['enable_attention_dp']}, "
                f"free_gpu_mem_fraction={test_case['free_gpu_mem_fraction']}, "
                f"max_batch_size={test_case['max_batch_size']}, "
                f"ISL={test_case['isl']}, OSL={test_case['osl']}, "
                f"max_num_tokens={test_case['max_num_tokens']}, "
                f"moe_max_num_tokens={test_case['moe_max_num_tokens']}, "
                f"Concurrency={concurrency}")
            with open(benchmark_log_filename, 'a') as f:
                f.write(f"\n{config_summary}\n")

            print("-----------------------------------------")
            return True  # Continue with next concurrency

        except Exception as e:
            print(
                f"Error running benchmark with concurrency {concurrency}: {e}")
            print(
                "Skipping this concurrency level and continuing with next one..."
            )
            print("-----------------------------------------")
            return True  # Continue with next concurrency, don't skip test case

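    # For illustration, a run_benchmark call for test case "R1-FP4" with tp=8,
    # ep=8, attn_backend=TRTLLM, moe_backend=CUTLASS, free_gpu_mem_fraction=0.9,
    # max_batch_size=128, isl=1024, osl=1024, max_num_tokens=8192,
    # moe_max_num_tokens=None, concurrency=32 and iteration=5 logs to:
    #   serve.R1-FP4.tp8.ep8.attnTRTLLM.moeCUTLASS.gpu0.9.batch128.isl1024.osl1024.tokens8192.moetokensNone.concurrency32.iter5.log
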
    def run_test_case(self, test_case: Dict[str, Any]) -> None:
        """Run a test case using the execution plan"""
        model_label = test_case['model']
        test_case_id = test_case['id']

        # Get model path
        model_path = self.model_paths.get(model_label)
        if not model_path:
            print(f"Error: No model path found for {model_label}")
            return

        # Use local path if it exists, otherwise use model name
        if os.path.exists(model_path):
            MODEL = model_path
        else:
            MODEL = model_label

        # Generate extra-llm-api-config.yml
        config_content = self.generate_extra_llm_api_config(test_case)
        config_path = "/tmp/extra-llm-api-config.yml"

        with open(config_path, 'w') as f:
            f.write(config_content)

        print("extra-llm-api-config.yml:")
        print(config_content)

        # Build trtllm-serve command
        serve_cmd = [
            "trtllm-serve", MODEL, "--backend", "pytorch", "--tp_size",
            str(test_case['tp']), "--ep_size",
            str(test_case['ep']), "--max_batch_size",
            str(test_case['max_batch_size']), "--max_num_tokens",
            str(test_case['max_num_tokens']),
            "--kv_cache_free_gpu_memory_fraction",
            str(test_case['free_gpu_mem_fraction']), "--extra_llm_api_options",
            config_path
        ]

        print("Starting trtllm-serve with command:")
        print(' '.join(serve_cmd))
        print()

        # Start server
        server_log_filename = (
            f"trtllm-serve.{model_label}.tp{test_case['tp']}.ep{test_case['ep']}."
            f"attn{test_case['attn_backend']}.moe{test_case['moe_backend']}."
            f"gpu{test_case['free_gpu_mem_fraction']}.batch{test_case['max_batch_size']}."
            f"isl{test_case['isl']}.osl{test_case['osl']}."
            f"tokens{test_case['max_num_tokens']}.moetokens{test_case['moe_max_num_tokens']}.log"
        )

        try:
            with open(server_log_filename, 'w') as log_file:
                log_file.write("extra-llm-api-config.yml:\n")
                log_file.write(config_content)
                log_file.write("\n")

            with open(server_log_filename, 'a') as log_file:
                server_process = subprocess.Popen(serve_cmd,
                                                  stdout=log_file,
                                                  stderr=subprocess.STDOUT)

            # Wait for server to be ready
            if not self.wait_for_server(server_process.pid,
                                        server_log_filename):
                print(
                    "Failed to start server, killing process and skipping this test case"
                )
                try:
                    subprocess.run(f"kill -9 {server_process.pid}",
                                   shell=True,
                                   check=False)
                    subprocess.run(
                        f"wait {server_process.pid} 2>/dev/null || true",
                        shell=True,
                        check=False)
                except Exception as e:
                    print(f"Warning: Error during server cleanup: {e}")
                return

            # Run benchmarks based on execution plan
            for concurrency_index in self.execution_plan[test_case_id]:
                concurrency, iteration = test_case['concurrency_iterations'][
                    concurrency_index]
                should_continue = self.run_benchmark(test_case, concurrency,
                                                     iteration, MODEL,
                                                     server_log_filename)

                # If run_benchmark returns False, skip the entire test case
                if not should_continue:
                    print(
                        f"RuntimeError detected - skipping remaining concurrencies for test case {test_case_id}"
                    )
                    break

        finally:
            # Cleanup: Kill server process using shell commands like in the original bash script
            print(f"Stopping server for {model_label}")
            try:
                # Use shell commands for more reliable process killing
                subprocess.run(f"kill -9 {server_process.pid}",
                               shell=True,
                               check=False)
                subprocess.run(f"wait {server_process.pid} 2>/dev/null || true",
                               shell=True,
                               check=False)
            except Exception as e:
                print(f"Warning: Error during server cleanup: {e}")

            time.sleep(5)  # Give it time to clean up resources
            print(f"Benchmark completed for {model_label}")
            print()

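    # For illustration, run_test_case for "70B-FP8" with tp=8, ep=1,
    # max_batch_size=128, max_num_tokens=8192 and free_gpu_mem_fraction=0.9
    # launches roughly:
    #   trtllm-serve /home/scratch.trt_llm_data/llm-models/llama-3.3-models/Llama-3.3-70B-Instruct-FP8 \
    #       --backend pytorch --tp_size 8 --ep_size 1 --max_batch_size 128 \
    #       --max_num_tokens 8192 --kv_cache_free_gpu_memory_fraction 0.9 \
    #       --extra_llm_api_options /tmp/extra-llm-api-config.yml
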
    def run_benchmarks(self) -> None:
        """Main function to run all benchmarks from config file"""
        script_start_time = time.time()

        print(f"Using config file: {self.config_file}")
        if self.select_pattern:
            print(f"Select pattern: {self.select_pattern}")
        else:
            print("Select pattern: default (all test cases)")
        if self.skip_pattern:
            print(f"Skip pattern: {self.skip_pattern}")
        else:
            print("Skip pattern: default (no skipping)")

        # Load configuration
        with open(self.config_file, 'r') as f:
            config = yaml.safe_load(f)

        test_cases = config['test_cases']

        # Build execution plan
        self.build_execution_plan(test_cases)

        # Print execution plan before starting benchmarks
        self.print_execution_plan(test_cases)

        # Run each test case based on execution plan
        for i, test_case in enumerate(test_cases, 1):
            test_case_id = test_case['id']

            if test_case_id not in self.execution_plan:
                print("=" * 57)
                print(
                    f"Test case {i}/{len(test_cases)} (ID: {test_case_id}): {test_case['model']} - SKIPPED"
                )
                print("=" * 57)
                continue

            print("=" * 57)
            print(
                f"Test case {i}/{len(test_cases)} (ID: {test_case_id}): {test_case['model']}"
            )
            print(
                f"Config: GPUs={test_case['gpus']}, TP={test_case['tp']}, EP={test_case['ep']}, attn_backend={test_case['attn_backend']}, moe_backend={test_case['moe_backend']}"
            )
            print("=" * 57)

            self.run_test_case(test_case)

        # Calculate and display total script runtime
        script_total_time = time.time() - script_start_time
        hours = int(script_total_time // 3600)
        minutes = int((script_total_time % 3600) // 60)
        seconds = int(script_total_time % 60)

        print("=" * 80)
        print("SCRIPT COMPLETION SUMMARY")
        print("=" * 80)
        print(
            f"Total script runtime: {hours:02d}:{minutes:02d}:{seconds:02d} (HH:MM:SS)"
        )
        print(f"Total runtime in seconds: {script_total_time:.2f}")
        print("=" * 80)
        print("All benchmarks completed!")


def main():
    parser = argparse.ArgumentParser(
        description='Run benchmarks from YAML configuration file')
    parser.add_argument('--output_folder',
                        required=True,
                        help='Output folder for benchmark results')
    parser.add_argument('--config_file',
                        required=True,
                        help='Path to YAML configuration file')
    parser.add_argument(
        '--skip',
        help=
        'Skip pattern: "2,4-1" means skip test case 2 and test case 4\'s 1st concurrency'
    )
    parser.add_argument(
        '--select',
        help=
        'Select pattern: "1,3,5" means only run test cases 1, 3, and 5; "1-1,2-3" means only run test case 1\'s 1st concurrency and test case 2\'s 3rd concurrency'
    )

    args = parser.parse_args()

    try:
        subprocess.run('echo "TRT-LLM GIT COMMIT": $TRT_LLM_GIT_COMMIT',
                       shell=True,
                       check=True)
    except subprocess.CalledProcessError:
        print("Warning: Could not echo TRT-LLM GIT COMMIT")

    if not os.path.exists(args.config_file):
        print(f"Error: Config file '{args.config_file}' does not exist")
        sys.exit(1)

    if not os.path.exists(args.output_folder):
        print(f"Error: Output folder '{args.output_folder}' does not exist")
        sys.exit(1)

    try:
        runner = BenchmarkRunner(args.output_folder, args.config_file,
                                 args.skip, args.select)
        runner.run_benchmarks()
    except Exception as e:
        print(f"Error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()