[None][feat] Add request timing breakdown option in benchmark_serving (#8128)

Signed-off-by: nv-yilinf <206948969+nv-yilinf@users.noreply.github.com>
This commit is contained in:
Yilin Fan 2025-10-10 09:24:54 -07:00 committed by GitHub
parent 85f157f389
commit 2695d70d42
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 1353 additions and 0 deletions

View File

@@ -45,6 +45,7 @@ from tensorrt_llm.serve.scripts.benchmark_dataset import (
SampleRequest, ShareGPTDataset, SonnetDataset, VisionArenaDataset)
from tensorrt_llm.serve.scripts.benchmark_utils import (
convert_to_pytorch_benchmark_format, write_to_json)
from tensorrt_llm.serve.scripts.time_breakdown import RequestTimeBreakdown
# isort: on
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@@ -598,6 +599,34 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
write_to_json(pt_file, pt_records)
async def fetch_perf_metrics(base_url: str) -> dict:
"""
Fetch performance metrics from the /perf_metrics endpoint.
Args:
base_url: The base URL of the server
Returns:
Dictionary containing the performance metrics
"""
perf_url = f"{base_url}/perf_metrics"
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
try:
async with session.get(perf_url) as response:
if response.status == 200:
return await response.json()
else:
print(
f"Failed to fetch performance metrics. Status: {response.status}"
)
return {}
except Exception as e:
print(f"Error fetching performance metrics: {e}")
return {}
def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
@@ -877,6 +906,55 @@ def main(args: argparse.Namespace):
json.dump(result_json, outfile)
save_to_pytorch_benchmark_format(args, result_json, file_name)
# Save per-request breakdown if requested
if args.save_request_time_breakdown:
print("Fetching request performance metrics...")
perf_metrics = asyncio.run(fetch_perf_metrics(base_url))
if perf_metrics:
# Generate filename for perf metrics
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
base_model_id = model_id.split("/")[-1]
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
if args.max_concurrency is not None else "")
perf_filename = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}-perf_metrics.json"
if args.result_dir:
perf_filename = os.path.join(args.result_dir, perf_filename)
# Save perf metrics to JSON file
with open(perf_filename, "w", encoding='utf-8') as outfile:
try:
json.dump(perf_metrics, outfile, indent=2)
except Exception as e:
print(f"Failed to save perf metrics: {e}")
print(f"Request performance metrics saved to: {perf_filename}")
# Create timing diagram from the saved JSON file
try:
analyzer = RequestTimeBreakdown()
print("Creating time diagram from request time breakdown...")
timing_data = analyzer.parse_json_file(perf_filename)
if timing_data:
# Generate HTML filename for the timing diagram
diagram_filename = f"{os.path.splitext(perf_filename)[0]}-time_diagram.html"
analyzer.create_timing_diagram(timing_data,
diagram_filename)
print(f"Time diagram saved to: {diagram_filename}")
else:
print(
"No time data found in request time breakdown - skipping diagram creation."
)
except Exception as e:
print(f"Failed to create time diagram: {e}")
print("Performance metrics were still saved successfully.")
else:
print("Failed to fetch per-request performance metrics.")
if __name__ == "__main__":
parser = FlexibleArgumentParser(
@@ -1260,6 +1338,13 @@ if __name__ == "__main__":
help="Skip initial test run with a single prompt.",
)
parser.add_argument(
"--save-request-time-breakdown",
action="store_true",
help=
"After benchmarking, call the /perf_metrics endpoint, save the result as JSON, and create an interactive time breakdown diagram.",
)
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,181 @@
# Time Breakdown Tool
A standalone tool for analyzing and visualizing TensorRT-LLM server request time breakdown.
## Overview
The Time Breakdown tool analyzes performance metrics from TensorRT-LLM servers and creates interactive visualizations showing how time is spent processing each request. It supports both aggregated and disaggregated server configurations.
The tool generates:
1. **Interactive HTML Diagram**: A stacked bar chart showing timing breakdown per request with hover tooltips
2. **Statistics**: Median times for each timing segment (optional)
### Example Visualization
![Request Time Breakdown Example](images/request_time_breakdown_example.png)
*Example of the interactive time diagram showing request time breakdown across different processing stages.*
### Timing Metrics
The tool aims to track detailed timing segments throughout the request lifecycle (currently only segments related to TTFT (Time-To-First-Token) are tracked; full lifecycle tracking will be added soon). A sketch of how each segment is computed follows the list below. The tracked segments are:
#### Context/Prefill Stage Metrics
1. **Context Preprocessing** (`ctx_preprocessing`)
- **Time Period**: `server_arrival_time` → `arrival_time`
- **Description**: Python overhead & initialization when the context server receives the request
- **Includes**: Request parsing, pre-processing (e.g., tokenization) before queuing
2. **Context Queue** (`ctx_queue`)
- **Time Period**: `arrival_time` → `first_scheduled_time`
- **Description**: Time spent waiting in queue and resource allocation
- **Includes**: Queueing delay, memory allocation, scheduling wait time
3. **Context Processing** (`ctx_processing`)
- **Time Period**: `first_scheduled_time` → `first_token_time`
- **Description**: Actual prefill computation time
- **Includes**: Model forward pass for the context/prompt tokens
4. **Context Postprocessing** (`ctx_postprocessing`)
- **Time Period**: `first_token_time` → `server_first_token_time`
- **Description**: Time to prepare and send the first token response
- **Includes**: Response preparation, serialization, network overhead
#### Generation/Decode Stage Metrics (Disaggregated Mode Only)
5. **Generation Preprocessing** (`gen_preprocessing`)
- **Time Period**: `gen_server_arrival_time` → `gen_arrival_time`
- **Description**: Python overhead & initialization when generation server receives the request
- **Includes**: Request parsing, KV cache transfer preparation
6. **Generation Queue** (`gen_queue`)
- **Time Period**: `gen_arrival_time` → `gen_first_scheduled_time`
- **Description**: Time spent in queue and resource allocation, including KV cache transfer
- **Includes**: Queueing delay, KV cache transfer, memory allocation for generation
7. **Generation First Token Postprocessing** (`gen_postprocessing`)
- **Time Period**: `gen_first_scheduled_time` → `gen_server_first_token_time`
- **Description**: Time to generate and send first token from generation server
- **Includes**: Token generation, response preparation
#### Disaggregation Server Metrics
8. **Disaggregation Preprocessing** (`disagg_preprocessing`)
- **Time Period**: `disagg_server_arrival_time` → `ctx_server_arrival_time`
- **Description**: Routing overhead from disagg server to context server
- **Includes**: Request forwarding, network latency
9. **Disaggregation Postprocessing** (`disagg_postprocessing`)
- **Time Period**: `gen_server_first_token_time` → `disagg_server_first_token_time`
- **Description**: Routing overhead from generation server back through disagg server
- **Includes**: Response forwarding, aggregation
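Each segment above is simply the difference between its two timestamps. Here is a minimal sketch of that calculation (illustrative only; the tool's own logic lives in `TimingMetric.calculate_duration`, which additionally prints a warning when a timestamp is missing or out of order):
```python
import math

def segment_duration(timing: dict, start_field: str, end_field: str) -> float:
    """Duration of one segment in seconds; 0 if either timestamp is missing or out of order."""
    start = timing.get(start_field, float("nan"))
    end = timing.get(end_field, float("nan"))
    if math.isnan(start) or math.isnan(end) or start > end:
        return 0.0
    return end - start

# Segments whose timestamps are absent (e.g. gen_* fields in aggregated mode) contribute 0.
print(segment_duration({"gen_arrival_time": 2.052},
                       "gen_arrival_time", "gen_first_scheduled_time"))  # prints 0.0
```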
## Input Format
The tool expects a JSON file containing an array of per-request performance metrics (all timestamps in seconds). Short worked examples follow the listings below.
### Aggregated Format
```json
[
{
"request_id": 0,
"perf_metrics": {
"timing_metrics": {
"server_arrival_time": 1.000,
"arrival_time": 1.002,
"first_scheduled_time": 1.005,
"first_token_time": 1.025,
"server_first_token_time": 1.027
}
}
}
]
```
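For the sample above, the TTFT segments work out as follows (a worked example; the dict below just restates the timestamps from the listing):
```python
t = {"server_arrival_time": 1.000, "arrival_time": 1.002, "first_scheduled_time": 1.005,
     "first_token_time": 1.025, "server_first_token_time": 1.027}

segments_ms = {
    "ctx_preprocessing": (t["arrival_time"] - t["server_arrival_time"]) * 1000,           # 2 ms
    "ctx_queue": (t["first_scheduled_time"] - t["arrival_time"]) * 1000,                  # 3 ms
    "ctx_processing": (t["first_token_time"] - t["first_scheduled_time"]) * 1000,         # 20 ms
    "ctx_postprocessing": (t["server_first_token_time"] - t["first_token_time"]) * 1000,  # 2 ms
}
print(segments_ms)
```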
### Disaggregated Format
```json
[
{
"ctx_perf_metrics": {
"request_id": 3,
"perf_metrics": {
"timing_metrics": {
"server_arrival_time": 2.000,
"arrival_time": 2.003,
"first_scheduled_time": 2.008,
"first_token_time": 2.035,
"server_first_token_time": 2.038
}
}
},
"gen_perf_metrics": {
"perf_metrics": {
"timing_metrics": {
"server_arrival_time": 2.050,
"arrival_time": 2.052,
"first_scheduled_time": 2.055,
"first_token_time": 2.080,
"server_first_token_time": 2.083
}
}
},
"disagg_server_arrival_time": 1.995,
"disagg_server_first_token_time": 2.090
}
]
```
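The two formats can be told apart by the presence of the `ctx_perf_metrics`/`gen_perf_metrics` keys. As a quick illustration (a minimal sketch, not the tool's own parser, which is `RequestDataParser`), the server-side time-to-first-token of each record can be recovered like this:
```python
import json

with open("perf_metrics.json") as f:  # path is a placeholder
    records = json.load(f)

for i, entry in enumerate(records):
    if "ctx_perf_metrics" in entry and "gen_perf_metrics" in entry:
        # Disaggregated: the outermost timestamps are taken by the disagg server.
        ttft = entry["disagg_server_first_token_time"] - entry["disagg_server_arrival_time"]
    else:
        # Aggregated: use the single server's timing metrics.
        t = entry["perf_metrics"]["timing_metrics"]
        ttft = t["server_first_token_time"] - t["server_arrival_time"]
    print(f"request {i}: server-side TTFT = {ttft * 1000:.2f} ms")
```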
## Usage
### Integration with Benchmark Serving
Step 1:
Set
```
return_perf_metrics: True
perf_metrics_max_requests: <INTEGER>
```
in the `extra-llm-api-config.yaml`. If you are running disaggregated serving, add these options to the configs of all servers (disagg, context, and generation).
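Before launching a long run, you can optionally sanity-check that the endpoint is enabled. A minimal sketch (the host and port are placeholders for your running server):
```python
import json
import urllib.request

# Hypothetical quick check; adjust host/port to your deployment.
with urllib.request.urlopen("http://localhost:8000/perf_metrics", timeout=10) as resp:
    records = json.load(resp)
print(f"/perf_metrics reachable; {len(records)} request records currently retained")
```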
Step 2:
Add `--save-request-time-breakdown` when running `benchmark_serving.py`
```
python -m tensorrt_llm.serve.scripts.benchmark_serving \
--model ${model_name} \
--dataset-name random \
--ignore-eos \
--num-prompts 1000 \
--random-input-len 1024 \
--random-output-len 2048 \
--random-ids \
--max-concurrency 64 \
--save-result \
--result-dir <RESULT_DIR> \
--percentile-metrics "ttft,tpot,itl,e2e" \
--save-request-time-breakdown
```
You will be able to find the raw metrics JSON (`*-perf_metrics.json`) and the interactive time diagram (`*-time_diagram.html`) in `<RESULT_DIR>`.
### As a CLI Tool
Step 1:
Query the `/perf_metrics` endpoint of the trtllm server and save the response as `perf_metrics.json` (for disaggregated serving, you only need to query the disagg server). Make sure the servers have `perf_metrics_max_requests` and `return_perf_metrics` configured.
```
curl -o perf_metrics.json <HOST>:<PORT>/perf_metrics
```
Step 2:
Process the `perf_metrics.json` with `time_breakdown.py`
```bash
# Basic usage - analyze and create time diagram
python time_breakdown.py perf_metrics.json
# Specify custom output file
python time_breakdown.py perf_metrics.json -o my_time_diagram.html
# Show statistics only (no diagram)
python time_breakdown.py perf_metrics.json --stats-only
# Create diagram and show statistics
python time_breakdown.py perf_metrics.json --show-stats
```

View File

@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""
Time Breakdown Analysis Tool
This module provides tools for analyzing and visualizing request time breakdown
from TensorRT-LLM server performance metrics.
"""
from .time_breakdown import (RequestDataParser, RequestTimeBreakdown,
TimingMetric, TimingMetricsConfig, main)
__all__ = [
'TimingMetric',
'TimingMetricsConfig',
'RequestDataParser',
'RequestTimeBreakdown',
'main',
]

View File

@@ -0,0 +1,13 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""
Entry point for running time_breakdown as a module.
Usage:
python -m tensorrt_llm.serve.scripts.time_breakdown perf_metrics.json [options]
"""
from .time_breakdown import main
if __name__ == '__main__':
main()

Binary file not shown (new image, 127 KiB)

View File

@@ -0,0 +1,550 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""
Time Breakdown Analysis Tool
This module provides tools for analyzing and visualizing request time breakdown
from TensorRT-LLM server performance metrics. It can be used both as a library
and as a standalone CLI tool.
Usage as CLI:
python time_breakdown.py <json_file> [options]
Usage as library:
from time_breakdown import RequestTimeBreakdown
analyzer = RequestTimeBreakdown()
timing_data = analyzer.parse_json_file("perf_metrics.json")
analyzer.create_timing_diagram(timing_data, "output.html")
"""
import argparse
import json
import math
import sys
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import numpy as np
import plotly.graph_objects as go
import plotly.offline as pyo
@dataclass
class TimingMetric:
"""Configuration for a timing metric segment."""
name: str
display_name: str
color: str
description: str
start_field: str
end_field: str
server_type: Optional[
str] = None # 'ctx', 'gen', 'disagg', or None for direct calculation
def calculate_duration(self, timing_data: Dict[str, float]) -> float:
"""Calculate the duration for this metric from timing data."""
start_time = timing_data.get(self.start_field, float('nan'))
end_time = timing_data.get(self.end_field, float('nan'))
# If either timestamp is NaN (not available), treat the segment as missing and return 0
if math.isnan(start_time) or math.isnan(end_time):
print(f"Warning: {self.name} has NaN start or end time")
return 0
if start_time > end_time:
print(f"Warning: {self.name} has start time after end time")
return 0
return end_time - start_time
class TimingMetricsConfig:
"""Configuration class that defines all available timing metrics."""
def __init__(self):
self.metrics = [
TimingMetric(
name='disagg_preprocessing',
display_name='Disagg Preprocessing',
color='#B8B8B8', # Light gray
description=
'Time from when the disagg server receives the request until a context server receives it',
start_field='disagg_server_arrival_time',
end_field='ctx_server_arrival_time',
server_type='disagg'),
TimingMetric(
name='ctx_preprocessing',
display_name='Context Preprocessing',
color='#90EE90', # Light green
description=
'Time from when a context server receives the request until an LLM worker queues it',
start_field='ctx_server_arrival_time',
end_field='ctx_arrival_time',
server_type='ctx'),
TimingMetric(
name='ctx_queue',
display_name='Context Queue',
color='#FFB347', # Light orange
description=
'Time from when the request is queued until it is first scheduled',
start_field='ctx_arrival_time',
end_field='ctx_first_scheduled_time',
server_type='ctx'),
TimingMetric(
name='ctx_processing',
display_name='Context Processing',
color='#6495ED', # Cornflower blue
description=
'Time from when the request is first scheduled until the first token is generated on an LLM worker',
start_field='ctx_first_scheduled_time',
end_field='ctx_first_token_time',
server_type='ctx'),
TimingMetric(
name='ctx_postprocessing',
display_name='Context Postprocessing',
color='#DDA0DD', # Plum
description=
'Time from when the first token is generated on an LLM worker until the first token response is sent by the context server',
start_field='ctx_first_token_time',
end_field='ctx_server_first_token_time',
server_type='ctx'),
TimingMetric(
name='gen_preprocessing',
display_name='Generation Preprocessing',
color='#FFE66D', # Bright yellow
description=
'Time from when a generation server receives the request until an LLM worker receives it',
start_field='gen_server_arrival_time',
end_field='gen_arrival_time',
server_type='gen'),
TimingMetric(
name='gen_queue',
display_name='Generation Queue',
color='#FF6B6B', # Coral red
description=
'Time from when the request is queued until it is first scheduled',
start_field='gen_arrival_time',
end_field='gen_first_scheduled_time',
server_type='gen'),
TimingMetric(
name='gen_postprocessing',
display_name='Generation Postprocessing',
color='#95E1D3', # Mint/teal
description=
'Time from when the request is first scheduled until the first token response is sent by the generation server',
start_field='gen_first_scheduled_time',
end_field='gen_server_first_token_time',
server_type='gen'),
TimingMetric(
name='disagg_postprocessing',
display_name='Disagg Postprocessing',
color='#A9A9A9', # Dark gray
description=
'Time from when the first token response is sent by the generation server until it is sent by the disagg server',
start_field='gen_server_first_token_time',
end_field='disagg_server_first_token_time',
server_type='disagg')
]
def get_metric_by_name(self, name: str) -> Optional[TimingMetric]:
"""Get a metric by its name."""
return next((m for m in self.metrics if m.name == name), None)
def get_metrics_by_server(self, server_type: str) -> List[TimingMetric]:
"""Get all metrics for a specific server type."""
return [m for m in self.metrics if m.server_type == server_type]
def add_metric(self, metric: TimingMetric):
"""Add a new timing metric."""
self.metrics.append(metric)
def remove_metric(self, name: str):
"""Remove a timing metric by name."""
self.metrics = [m for m in self.metrics if m.name != name]
class RequestDataParser:
"""Parser for disaggregated format with ctx_perf_metrics and gen_perf_metrics."""
def parse_request(self, request_data: Dict,
request_index: int) -> Dict[str, Any]:
is_disaggregated = 'ctx_perf_metrics' in request_data and 'gen_perf_metrics' in request_data
ctx_metrics = {}
gen_metrics = {}
if is_disaggregated:
ctx_metrics = request_data.get('ctx_perf_metrics', {}).get(
'perf_metrics', {}).get('timing_metrics', {})
gen_metrics = request_data.get('gen_perf_metrics', {}).get(
'perf_metrics', {}).get('timing_metrics', {})
else:
ctx_metrics = request_data.get('perf_metrics',
{}).get('timing_metrics', {})
ctx_arrival_time = ctx_metrics.get('arrival_time', 0)
ctx_first_scheduled_time = ctx_metrics.get('first_scheduled_time', 0)
ctx_first_token_time = ctx_metrics.get('first_token_time', 0)
ctx_server_arrival_time = ctx_metrics.get('server_arrival_time', 0)
ctx_server_first_token_time = ctx_metrics.get('server_first_token_time',
0)
gen_server_first_token_time = gen_metrics.get('server_first_token_time',
0)
gen_server_arrival_time = gen_metrics.get('server_arrival_time', 0)
gen_arrival_time = gen_metrics.get('arrival_time', 0)
gen_first_token_time = gen_metrics.get('first_token_time', 0)
gen_first_scheduled_time = gen_metrics.get('first_scheduled_time', 0)
disagg_server_arrival_time = 0
disagg_server_first_token_time = 0
if is_disaggregated:
disagg_server_arrival_time = request_data.get(
'disagg_server_arrival_time', 0)
disagg_server_first_token_time = request_data.get(
'disagg_server_first_token_time', 0)
# Get request ID
if is_disaggregated:
request_id = request_data.get('ctx_perf_metrics',
{}).get('request_id', request_index)
else:
request_id = request_data.get('request_id', request_index)
return {
'request_index': request_id,
'ctx_server_arrival_time': ctx_server_arrival_time,
'ctx_arrival_time': ctx_arrival_time,
'ctx_first_scheduled_time': ctx_first_scheduled_time,
'ctx_first_token_time': ctx_first_token_time,
'ctx_server_first_token_time': ctx_server_first_token_time,
'gen_server_arrival_time': gen_server_arrival_time,
'gen_arrival_time': gen_arrival_time,
'gen_first_scheduled_time': gen_first_scheduled_time,
'gen_first_token_time': gen_first_token_time,
'gen_server_first_token_time': gen_server_first_token_time,
'disagg_server_arrival_time': disagg_server_arrival_time,
'disagg_server_first_token_time': disagg_server_first_token_time,
}
class RequestTimeBreakdown:
"""Main class for analyzing request time breakdown."""
def __init__(self, config: Optional[TimingMetricsConfig] = None):
self.config = config or TimingMetricsConfig()
self.parser = RequestDataParser()
def parse_json_file(self, json_file_path: str) -> List[Dict]:
"""Parse JSON performance metrics file and extract timing information."""
try:
with open(json_file_path, 'r') as f:
data = json.load(f)
except FileNotFoundError:
print(f"Error: File '{json_file_path}' not found.")
sys.exit(1)
except json.JSONDecodeError as e:
print(f"Error parsing JSON file '{json_file_path}': {e}")
sys.exit(1)
except Exception as e:
print(f"Error reading file '{json_file_path}': {e}")
sys.exit(1)
timing_data = []
for i, request in enumerate(data):
parsed_data = self.parser.parse_request(request, i)
# Calculate durations for each metric
for metric in self.config.metrics:
duration = metric.calculate_duration(parsed_data)
parsed_data[f'{metric.name}_time'] = duration
timing_data.append(parsed_data)
if timing_data:
has_gen_metrics = any(entry['gen_server_arrival_time'] > 0
for entry in timing_data)
format_type = "disaggregated" if has_gen_metrics else "aggregated"
print(
f"Parsed timing data for {len(timing_data)} requests from {json_file_path} ({format_type} format)"
)
else:
print(f"Parsed timing data for 0 requests from {json_file_path}")
return timing_data
def create_timing_diagram(self,
timing_data: List[Dict],
output_file: Optional[str] = None):
"""Create an interactive HTML stacked bar chart showing time breakdown."""
if not timing_data:
print("No timing data to visualize.")
return
# Extract data for plotting
request_indices = [data['request_index'] for data in timing_data]
# Create the interactive plot
fig = go.Figure()
# Add traces for each metric
for metric in self.config.metrics:
times_ms = [
data.get(f'{metric.name}_time', 0) * 1000
for data in timing_data
]
# Only add trace if there's some non-zero data
if any(t > 0 for t in times_ms):
fig.add_trace(
go.Bar(
x=request_indices,
y=times_ms,
name=metric.display_name,
marker_color=metric.color,
hovertemplate=
f'<b>Request %{{x}}</b><br>{metric.display_name}: %{{y:.2f}} ms<extra></extra>'
))
# Update layout
fig.update_layout(
barmode='stack',
title={
'text':
'Request Time Breakdown<br><sub>Time Spent in Each Segment (Interactive)</sub>',
'x': 0.5,
'xanchor': 'center',
'font': {
'size': 16
}
},
xaxis_title='Request Index',
yaxis_title='Time (milliseconds)',
hovermode='x unified',
legend=dict(orientation="v",
yanchor="top",
y=1,
xanchor="left",
x=1.02),
width=1200,
height=700,
margin=dict(r=200))
# Calculate and add statistics
self._add_statistics_annotation(fig, timing_data)
# Set default output filename if not provided
if not output_file:
output_file = 'time_breakdown.html'
elif not output_file.endswith('.html'):
output_file += '.html'
# Generate the plotly div
plot_div = pyo.plot(fig,
output_type='div',
include_plotlyjs='cdn',
auto_open=False)
# Generate descriptions HTML
descriptions_html = self._generate_descriptions_html(timing_data)
# Combine into full HTML
full_html = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Request Timing Breakdown</title>
<style>
body {{
font-family: Arial, sans-serif;
margin: 20px;
}}
.descriptions {{
margin-top: 30px;
padding: 20px;
background-color: #f5f5f5;
border: 1px solid #ddd;
border-radius: 5px;
max-width: 1200px;
}}
.descriptions h2 {{
margin-top: 0;
color: #333;
}}
.metric-desc {{
margin-bottom: 15px;
line-height: 1.6;
}}
.metric-name {{
font-weight: bold;
color: #2c3e50;
}}
</style>
</head>
<body>
{plot_div}
{descriptions_html}
</body>
</html>
"""
# Write to file
with open(output_file, 'w') as f:
f.write(full_html)
print(f"Interactive time breakdown diagram saved to: {output_file}")
print("Open the file in your web browser to interact with the chart!")
def _add_statistics_annotation(self, fig, timing_data: List[Dict]):
"""Add statistics annotation to the plot."""
# Calculate median times for each metric
stats_lines = ['<b>Median Times (ms):</b>']
total_times = []
for metric in self.config.metrics:
times = [
data.get(f'{metric.name}_time', 0) * 1000
for data in timing_data
]
if any(t > 0 for t in times):
median_time = np.median(times)
stats_lines.append(f'{metric.display_name}: {median_time:.2f}')
# Calculate total time per request
for data in timing_data:
total = sum(
data.get(f'{metric.name}_time', 0) * 1000
for metric in self.config.metrics)
total_times.append(total)
if total_times:
median_total = np.median(total_times)
stats_lines.append(f'<b>Total per Request: {median_total:.2f}</b>')
stats_lines.append(f'<b>Requests: {len(timing_data)}</b>')
stats_text = '<br>'.join(stats_lines)
fig.add_annotation(x=0.98,
y=0.98,
xref='paper',
yref='paper',
text=stats_text,
showarrow=False,
align='right',
bgcolor='rgba(255, 255, 255, 0.8)',
bordercolor='black',
borderwidth=1,
font=dict(size=10))
def _generate_descriptions_html(self, timing_data: List[Dict]) -> str:
"""Generate HTML for metric descriptions section."""
desc_items = []
for metric in self.config.metrics:
times = [
data.get(f'{metric.name}_time', 0) * 1000
for data in timing_data
]
# Only include metrics that have non-zero data
if any(t > 0 for t in times):
desc_items.append(
f'<div class="metric-desc">'
f'<span class="metric-name">{metric.display_name}:</span> '
f'{metric.description}'
f'</div>')
if not desc_items:
return ''
descriptions_html = f"""
<div class="descriptions">
<h2>Metric Descriptions</h2>
{''.join(desc_items)}
Reference: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/serve/scripts/time_breakdown/README.md
</div>
"""
return descriptions_html
def show_statistics(self, timing_data: List[Dict]):
"""Show detailed statistics about the timing data."""
if not timing_data:
print("No timing data to analyze.")
return
print("\n=== Timing Statistics ===")
print(f"Total requests: {len(timing_data)}")
for metric in self.config.metrics:
times = [data.get(f'{metric.name}_time', 0) for data in timing_data]
if any(t > 0 for t in times):
print(f"\n{metric.display_name} Times (seconds):")
print(f" Range: {min(times):.3f} to {max(times):.3f}")
print(f" Median: {np.median(times):.3f}")
print(f" Description: {metric.description}")
def main():
"""Main CLI entry point."""
parser = argparse.ArgumentParser(
description='Analyze and visualize TensorRT-LLM server time breakdown',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Analyze performance metrics and create timing diagram
python time_breakdown.py perf_metrics.json
# Specify custom output file
python time_breakdown.py perf_metrics.json -o my_timing.html
# Show statistics only (no diagram)
python time_breakdown.py perf_metrics.json --stats-only
# Create diagram and show statistics
python time_breakdown.py perf_metrics.json --show-stats
""")
parser.add_argument('json_file',
type=str,
help='Path to the JSON performance metrics file')
parser.add_argument(
'-o',
'--output',
type=str,
default=None,
help='Output HTML file path (default: time_breakdown.html)')
parser.add_argument('--stats-only',
action='store_true',
help='Show statistics only without creating diagram')
parser.add_argument('--show-stats',
action='store_true',
help='Show statistics in addition to creating diagram')
args = parser.parse_args()
# Create analyzer
analyzer = RequestTimeBreakdown()
# Parse the JSON file
print(f"Parsing timing data from: {args.json_file}")
timing_data = analyzer.parse_json_file(args.json_file)
if not timing_data:
print("No timing data found in the file.")
sys.exit(1)
# Show statistics if requested
if args.stats_only or args.show_stats:
analyzer.show_statistics(timing_data)
# Create diagram unless stats-only mode
if not args.stats_only:
analyzer.create_timing_diagram(timing_data, args.output)
if __name__ == '__main__':
main()

View File

@@ -20,6 +20,7 @@ l0_a10:
# NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
# test list either).
- unittest/_torch/models/checkpoints/hf/test_weight_loader.py
- unittest/others/test_time_breakdown.py
- disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun_trt_backend[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]

View File

@@ -0,0 +1,504 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""
Unit tests for time_breakdown module
Run tests with:
python -m pytest tests/unittest/others/test_time_breakdown.py -v
or
python -m unittest tests.unittest.others.test_time_breakdown
"""
import json
import os
import tempfile
import unittest
from unittest.mock import patch
from tensorrt_llm.serve.scripts.time_breakdown import (RequestDataParser,
RequestTimeBreakdown,
TimingMetric,
TimingMetricsConfig)
class TestTimingMetric(unittest.TestCase):
"""Test TimingMetric class."""
def test_timing_metric_creation(self):
"""Test basic TimingMetric creation."""
metric = TimingMetric(name='test_metric',
display_name='Test Metric',
color='blue',
description='Test description',
start_field='start_time',
end_field='end_time',
server_type='ctx')
self.assertEqual(metric.name, 'test_metric')
self.assertEqual(metric.display_name, 'Test Metric')
self.assertEqual(metric.color, 'blue')
self.assertEqual(metric.description, 'Test description')
self.assertEqual(metric.start_field, 'start_time')
self.assertEqual(metric.end_field, 'end_time')
self.assertEqual(metric.server_type, 'ctx')
def test_calculate_duration_valid(self):
"""Test duration calculation with valid timestamps."""
metric = TimingMetric(name='test',
display_name='Test',
color='blue',
description='Test',
start_field='start_time',
end_field='end_time')
timing_data = {'start_time': 1.0, 'end_time': 3.5}
duration = metric.calculate_duration(timing_data)
self.assertEqual(duration, 2.5)
def test_calculate_duration_missing_start(self):
"""Test duration calculation with missing start time."""
metric = TimingMetric(name='test',
display_name='Test',
color='blue',
description='Test',
start_field='start_time',
end_field='end_time')
timing_data = {'end_time': 3.5}
duration = metric.calculate_duration(timing_data)
self.assertEqual(duration, 0.0)
def test_calculate_duration_missing_end(self):
"""Test duration calculation with missing end time."""
metric = TimingMetric(name='test',
display_name='Test',
color='blue',
description='Test',
start_field='start_time',
end_field='end_time')
timing_data = {'start_time': 1.0, 'end_time': 0}
duration = metric.calculate_duration(timing_data)
self.assertEqual(duration, 0.0)
def test_calculate_duration_negative(self):
"""Test duration calculation doesn't produce negative values."""
metric = TimingMetric(name='test',
display_name='Test',
color='blue',
description='Test',
start_field='start_time',
end_field='end_time')
timing_data = {'start_time': 5.0, 'end_time': 3.5}
duration = metric.calculate_duration(timing_data)
self.assertEqual(duration, 0.0)
class TestTimingMetricsConfig(unittest.TestCase):
"""Test TimingMetricsConfig class."""
def test_default_metrics_loaded(self):
"""Test that default metrics are loaded."""
config = TimingMetricsConfig()
# Should have multiple default metrics
self.assertGreater(len(config.metrics), 0)
# Check for expected metric names
metric_names = [m.name for m in config.metrics]
self.assertIn('ctx_preprocessing', metric_names)
self.assertIn('ctx_processing', metric_names)
def test_get_metric_by_name(self):
"""Test retrieving a metric by name."""
config = TimingMetricsConfig()
metric = config.get_metric_by_name('ctx_preprocessing')
self.assertIsNotNone(metric)
self.assertEqual(metric.name, 'ctx_preprocessing')
# Test non-existent metric
metric = config.get_metric_by_name('non_existent')
self.assertIsNone(metric)
def test_get_metrics_by_server(self):
"""Test retrieving metrics by server type."""
config = TimingMetricsConfig()
ctx_metrics = config.get_metrics_by_server('ctx')
self.assertGreater(len(ctx_metrics), 0)
# All returned metrics should be for 'ctx' server
for metric in ctx_metrics:
self.assertEqual(metric.server_type, 'ctx')
def test_add_metric(self):
"""Test adding a new metric."""
config = TimingMetricsConfig()
initial_count = len(config.metrics)
new_metric = TimingMetric(name='custom_metric',
display_name='Custom Metric',
color='red',
description='Custom test metric',
start_field='start',
end_field='end')
config.add_metric(new_metric)
self.assertEqual(len(config.metrics), initial_count + 1)
self.assertIsNotNone(config.get_metric_by_name('custom_metric'))
def test_remove_metric(self):
"""Test removing a metric."""
config = TimingMetricsConfig()
initial_count = len(config.metrics)
# Add a test metric first
test_metric = TimingMetric(name='test_to_remove',
display_name='Test',
color='blue',
description='Test',
start_field='start',
end_field='end')
config.add_metric(test_metric)
# Remove it
config.remove_metric('test_to_remove')
self.assertEqual(len(config.metrics), initial_count)
self.assertIsNone(config.get_metric_by_name('test_to_remove'))
class TestRequestDataParser(unittest.TestCase):
"""Test RequestDataParser class."""
def test_parse_aggregated_format(self):
"""Test parsing aggregated (non-disaggregated) format."""
parser = RequestDataParser()
request_data = {
'request_id': 'req123',
'perf_metrics': {
'timing_metrics': {
'server_arrival_time': 1.0,
'arrival_time': 1.1,
'first_scheduled_time': 1.2,
'first_token_time': 1.5,
'server_first_token_time': 1.6
}
}
}
parsed = parser.parse_request(request_data, 0)
self.assertEqual(parsed['request_index'], 'req123')
self.assertEqual(parsed['ctx_server_arrival_time'], 1.0)
self.assertEqual(parsed['ctx_arrival_time'], 1.1)
self.assertEqual(parsed['ctx_first_scheduled_time'], 1.2)
self.assertEqual(parsed['ctx_first_token_time'], 1.5)
self.assertEqual(parsed['ctx_server_first_token_time'], 1.6)
# Gen metrics should be 0 in aggregated format
self.assertEqual(parsed['gen_server_arrival_time'], 0)
self.assertEqual(parsed['disagg_server_arrival_time'], 0)
def test_parse_disaggregated_format(self):
"""Test parsing disaggregated format."""
parser = RequestDataParser()
request_data = {
'ctx_perf_metrics': {
'request_id': 'req456',
'perf_metrics': {
'timing_metrics': {
'server_arrival_time': 1.0,
'arrival_time': 1.1,
'first_scheduled_time': 1.2,
'first_token_time': 1.5,
'server_first_token_time': 1.6
}
}
},
'gen_perf_metrics': {
'perf_metrics': {
'timing_metrics': {
'server_arrival_time': 2.0,
'arrival_time': 2.1,
'first_scheduled_time': 2.2,
'first_token_time': 2.5,
'server_first_token_time': 2.6
}
}
},
'disagg_server_arrival_time': 0.5,
'disagg_server_first_token_time': 3.0
}
parsed = parser.parse_request(request_data, 0)
self.assertEqual(parsed['request_index'], 'req456')
# Context metrics
self.assertEqual(parsed['ctx_server_arrival_time'], 1.0)
self.assertEqual(parsed['ctx_arrival_time'], 1.1)
# Generation metrics
self.assertEqual(parsed['gen_server_arrival_time'], 2.0)
self.assertEqual(parsed['gen_arrival_time'], 2.1)
# Disaggregation metrics
self.assertEqual(parsed['disagg_server_arrival_time'], 0.5)
self.assertEqual(parsed['disagg_server_first_token_time'], 3.0)
def test_parse_missing_fields(self):
"""Test parsing with missing fields (should default to 0)."""
parser = RequestDataParser()
request_data = {
'request_id': 'req789',
'perf_metrics': {
'timing_metrics': {}
}
}
parsed = parser.parse_request(request_data, 0)
# All timing fields should default to 0
self.assertEqual(parsed['ctx_server_arrival_time'], 0)
self.assertEqual(parsed['ctx_arrival_time'], 0)
self.assertEqual(parsed['gen_server_arrival_time'], 0)
def test_parse_uses_index_as_fallback(self):
"""Test that index is used when request_id is missing."""
parser = RequestDataParser()
request_data = {'perf_metrics': {'timing_metrics': {}}}
parsed = parser.parse_request(request_data, 42)
self.assertEqual(parsed['request_index'], 42)
class TestRequestTimeBreakdown(unittest.TestCase):
"""Test RequestTimeBreakdown class."""
def setUp(self):
"""Set up test fixtures."""
self.analyzer = RequestTimeBreakdown()
# Create a temporary JSON file for testing
self.test_data = [{
'request_id': 0,
'perf_metrics': {
'timing_metrics': {
'server_arrival_time': 1.0,
'arrival_time': 1.1,
'first_scheduled_time': 1.2,
'first_token_time': 1.5,
'server_first_token_time': 1.6
}
}
}, {
'request_id': 1,
'perf_metrics': {
'timing_metrics': {
'server_arrival_time': 2.0,
'arrival_time': 2.1,
'first_scheduled_time': 2.3,
'first_token_time': 2.7,
'server_first_token_time': 2.8
}
}
}]
def test_parse_json_file(self):
"""Test parsing a JSON file."""
# Create a temporary file
with tempfile.NamedTemporaryFile(mode='w', suffix='.json',
delete=False) as f:
json.dump(self.test_data, f)
temp_file = f.name
try:
timing_data = self.analyzer.parse_json_file(temp_file)
self.assertEqual(len(timing_data), 2)
# Check first request
self.assertEqual(timing_data[0]['request_index'], 0)
self.assertEqual(timing_data[0]['ctx_server_arrival_time'], 1.0)
# Check that durations were calculated
self.assertIn('ctx_preprocessing_time', timing_data[0])
self.assertIn('ctx_queue_time', timing_data[0])
# Verify a specific duration calculation
# ctx_preprocessing = ctx_arrival_time - ctx_server_arrival_time
expected_preprocessing = 1.1 - 1.0
self.assertAlmostEqual(timing_data[0]['ctx_preprocessing_time'],
expected_preprocessing,
places=5)
finally:
os.unlink(temp_file)
def test_parse_json_file_not_found(self):
"""Test parsing a non-existent file."""
with self.assertRaises(SystemExit):
self.analyzer.parse_json_file('non_existent_file.json')
def test_parse_json_file_invalid_json(self):
"""Test parsing an invalid JSON file."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json',
delete=False) as f:
f.write("{ invalid json")
temp_file = f.name
try:
with self.assertRaises(SystemExit):
self.analyzer.parse_json_file(temp_file)
finally:
os.unlink(temp_file)
def test_create_timing_diagram(self):
"""Test creating a timing diagram."""
# Create sample timing data
timing_data = [{
'request_index': 0,
'ctx_preprocessing_time': 0.1,
'ctx_queue_time': 0.2,
'ctx_processing_time': 0.3,
'ctx_postprocessing_time': 0.05,
'gen_preprocessing_time': 0,
'gen_queue_time': 0,
'gen_postprocessing_time': 0,
'disagg_preprocessing_time': 0,
'disagg_postprocessing_time': 0,
}, {
'request_index': 1,
'ctx_preprocessing_time': 0.15,
'ctx_queue_time': 0.25,
'ctx_processing_time': 0.35,
'ctx_postprocessing_time': 0.06,
'gen_preprocessing_time': 0,
'gen_queue_time': 0,
'gen_postprocessing_time': 0,
'disagg_preprocessing_time': 0,
'disagg_postprocessing_time': 0,
}]
with tempfile.NamedTemporaryFile(suffix='.html', delete=False) as f:
temp_file = f.name
try:
# Mock plotly to avoid actual file creation
with patch(
'tensorrt_llm.serve.scripts.time_breakdown.time_breakdown.pyo.plot'
) as mock_plot:
self.analyzer.create_timing_diagram(timing_data, temp_file)
# Verify that plot was called
mock_plot.assert_called_once()
finally:
if os.path.exists(temp_file):
os.unlink(temp_file)
def test_create_timing_diagram_empty_data(self):
"""Test creating a diagram with empty data."""
# Should handle gracefully without creating a file
with patch('builtins.print') as mock_print:
self.analyzer.create_timing_diagram([])
mock_print.assert_called_with("No timing data to visualize.")
def test_show_statistics(self):
"""Test showing statistics."""
timing_data = [{
'ctx_preprocessing_time': 0.1,
'ctx_queue_time': 0.2,
'ctx_processing_time': 0.3,
}, {
'ctx_preprocessing_time': 0.15,
'ctx_queue_time': 0.25,
'ctx_processing_time': 0.35,
}]
# Capture printed output
with patch('builtins.print') as mock_print:
self.analyzer.show_statistics(timing_data)
# Should have printed something
self.assertTrue(mock_print.called)
# Check for expected content in printed output
printed_output = ' '.join(
[str(call[0][0]) for call in mock_print.call_args_list])
self.assertIn('Total requests', printed_output)
def test_show_statistics_empty_data(self):
"""Test showing statistics with empty data."""
with patch('builtins.print') as mock_print:
self.analyzer.show_statistics([])
mock_print.assert_called_with("No timing data to analyze.")
def test_custom_config(self):
"""Test using a custom configuration."""
custom_config = TimingMetricsConfig()
custom_config.add_metric(
TimingMetric(name='custom_metric',
display_name='Custom',
color='red',
description='Custom metric',
start_field='custom_start',
end_field='custom_end'))
analyzer = RequestTimeBreakdown(config=custom_config)
# Verify custom config is used
self.assertIsNotNone(
analyzer.config.get_metric_by_name('custom_metric'))
class TestIntegration(unittest.TestCase):
"""Integration tests for the full workflow."""
def test_full_workflow(self):
"""Test the complete workflow from file to diagram."""
# Create test data
test_data = [{
'request_id': i,
'perf_metrics': {
'timing_metrics': {
'server_arrival_time': float(i),
'arrival_time': float(i) + 0.1,
'first_scheduled_time': float(i) + 0.2,
'first_token_time': float(i) + 0.5,
'server_first_token_time': float(i) + 0.6
}
}
} for i in range(5)]
# Create temporary files and run the complete workflow
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=True) as json_f, \
tempfile.NamedTemporaryFile(suffix='.html', delete=True) as html_f:
# Write test data to JSON file
json.dump(test_data, json_f)
json_f.flush() # Ensure data is written before reading
analyzer = RequestTimeBreakdown()
timing_data = analyzer.parse_json_file(json_f.name)
# Verify parsing
self.assertEqual(len(timing_data), 5)
# Mock the plot function to avoid actual file operations
with patch(
'tensorrt_llm.serve.scripts.time_breakdown.time_breakdown.pyo.plot'
):
analyzer.create_timing_diagram(timing_data, html_f.name)
if __name__ == '__main__':
unittest.main()