[None][feat] Add request timing breakdown option in benchmark_serving (#8128)
Signed-off-by: nv-yilinf <206948969+nv-yilinf@users.noreply.github.com>
commit 2695d70d42 (parent 85f157f389)
@@ -45,6 +45,7 @@ from tensorrt_llm.serve.scripts.benchmark_dataset import (
    SampleRequest, ShareGPTDataset, SonnetDataset, VisionArenaDataset)
from tensorrt_llm.serve.scripts.benchmark_utils import (
    convert_to_pytorch_benchmark_format, write_to_json)
from tensorrt_llm.serve.scripts.time_breakdown import RequestTimeBreakdown
# isort: on

MILLISECONDS_TO_SECONDS_CONVERSION = 1000

@@ -598,6 +599,34 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
    write_to_json(pt_file, pt_records)


async def fetch_perf_metrics(base_url: str) -> dict:
    """
    Fetch performance metrics from the /perf_metrics endpoint.

    Args:
        base_url: The base URL of the server.

    Returns:
        Dictionary containing the performance metrics.
    """
    perf_url = f"{base_url}/perf_metrics"

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        try:
            async with session.get(perf_url) as response:
                if response.status == 200:
                    return await response.json()
                else:
                    print(
                        f"Failed to fetch performance metrics. Status: {response.status}"
                    )
                    return {}
        except Exception as e:
            print(f"Error fetching performance metrics: {e}")
            return {}


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)

@@ -877,6 +906,55 @@ def main(args: argparse.Namespace):
            json.dump(result_json, outfile)
        save_to_pytorch_benchmark_format(args, result_json, file_name)

    # Save the per-request time breakdown if requested
    if args.save_request_time_breakdown:
        print("Fetching request performance metrics...")
        perf_metrics = asyncio.run(fetch_perf_metrics(base_url))

        if perf_metrics:
            # Generate the filename for the perf metrics
            current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
            base_model_id = model_id.split("/")[-1]
            max_concurrency_str = (f"-concurrency{args.max_concurrency}"
                                   if args.max_concurrency is not None else "")
            perf_filename = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}-perf_metrics.json"

            if args.result_dir:
                perf_filename = os.path.join(args.result_dir, perf_filename)

            # Save the perf metrics to a JSON file
            with open(perf_filename, "w", encoding='utf-8') as outfile:
                try:
                    json.dump(perf_metrics, outfile, indent=2)
                except Exception as e:
                    print(f"Failed to save perf metrics: {e}")

            print(f"Request performance metrics saved to: {perf_filename}")

            # Create a timing diagram from the saved JSON file
            try:
                analyzer = RequestTimeBreakdown()

                print("Creating time diagram from request time breakdown...")
                timing_data = analyzer.parse_json_file(perf_filename)

                if timing_data:
                    # Generate the HTML filename for the timing diagram
                    diagram_filename = f"{os.path.splitext(perf_filename)[0]}-time_diagram.html"
                    analyzer.create_timing_diagram(timing_data,
                                                   diagram_filename)

                    print(f"Time diagram saved to: {diagram_filename}")
                else:
                    print(
                        "No time data found in request time breakdown - skipping diagram creation."
                    )
            except Exception as e:
                print(f"Failed to create time diagram: {e}")
                print("Performance metrics were still saved successfully.")
        else:
            print("Failed to fetch per-request performance metrics.")


if __name__ == "__main__":
    parser = FlexibleArgumentParser(

@@ -1260,6 +1338,13 @@ if __name__ == "__main__":
        help="Skip initial test run with a single prompt.",
    )

    parser.add_argument(
        "--save-request-time-breakdown",
        action="store_true",
        help=
        "After benchmarking, call the /perf_metrics endpoint, save the result as JSON, and create an interactive time breakdown diagram.",
    )

    args = parser.parse_args()

    main(args)
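For a quick sanity check outside the benchmark flow, the helper added above can also be driven directly. A minimal sketch, assuming `fetch_perf_metrics` is importable from the benchmark script and that a server is running locally (the URL is a placeholder):

```python
import asyncio

from tensorrt_llm.serve.scripts.benchmark_serving import fetch_perf_metrics


async def _check():
    # Fetch the raw per-request records from the /perf_metrics endpoint and
    # report how many came back.
    records = await fetch_perf_metrics("http://localhost:8000")
    print(f"Fetched {len(records)} per-request metric records")


asyncio.run(_check())
```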
tensorrt_llm/serve/scripts/time_breakdown/README.md (new file, 181 lines)
@@ -0,0 +1,181 @@
# Time Breakdown Tool

A standalone tool for analyzing and visualizing the request time breakdown of a TensorRT-LLM server.

## Overview

The Time Breakdown tool analyzes performance metrics from TensorRT-LLM servers and creates interactive visualizations showing how time is spent processing each request. It supports both aggregated and disaggregated server configurations.

The tool generates:

1. **Interactive HTML Diagram**: A stacked bar chart showing the timing breakdown per request, with hover tooltips
2. **Statistics**: Median times for each timing segment (optional)

### Example Visualization

![Example Time Diagram](media/example_time_diagram.png)

*Example of the interactive time diagram showing the request time breakdown across different processing stages.*

### Timing Metrics

The tool aims to track detailed timing segments throughout the request lifecycle. Currently only the segments that contribute to TTFT (Time-To-First-Token) are tracked; full-lifecycle tracking will be added later. The tracked segments are listed below, followed by a short worked example of computing them from raw timestamps.

#### Context/Prefill Stage Metrics

1. **Context Preprocessing** (`ctx_preprocessing`)
   - **Time Period**: `server_arrival_time` → `arrival_time`
   - **Description**: Python overhead and initialization when the context server receives the request
   - **Includes**: Request parsing and pre-processing (e.g., tokenization) before queuing

2. **Context Queue** (`ctx_queue`)
   - **Time Period**: `arrival_time` → `first_scheduled_time`
   - **Description**: Time spent waiting in the queue and on resource allocation
   - **Includes**: Queueing delay, memory allocation, scheduling wait time

3. **Context Processing** (`ctx_processing`)
   - **Time Period**: `first_scheduled_time` → `first_token_time`
   - **Description**: Actual prefill computation time
   - **Includes**: Model forward pass over the context/prompt tokens

4. **Context Postprocessing** (`ctx_postprocessing`)
   - **Time Period**: `first_token_time` → `server_first_token_time`
   - **Description**: Time to prepare and send the first token response
   - **Includes**: Response preparation, serialization, network overhead

#### Generation/Decode Stage Metrics (Disaggregated Mode Only)

5. **Generation Preprocessing** (`gen_preprocessing`)
   - **Time Period**: `gen_server_arrival_time` → `gen_arrival_time`
   - **Description**: Python overhead and initialization when the generation server receives the request
   - **Includes**: Request parsing, KV cache transfer preparation

6. **Generation Queue** (`gen_queue`)
   - **Time Period**: `gen_arrival_time` → `gen_first_scheduled_time`
   - **Description**: Time spent in the queue and on resource allocation, including KV cache transfer
   - **Includes**: Queueing delay, KV cache transfer, memory allocation for generation

7. **Generation First Token Postprocessing** (`gen_postprocessing`)
   - **Time Period**: `gen_first_scheduled_time` → `gen_server_first_token_time`
   - **Description**: Time to generate and send the first token from the generation server
   - **Includes**: Token generation, response preparation

#### Disaggregation Server Metrics

8. **Disaggregation Preprocessing** (`disagg_preprocessing`)
   - **Time Period**: `disagg_server_arrival_time` → `ctx_server_arrival_time`
   - **Description**: Routing overhead from the disagg server to the context server
   - **Includes**: Request forwarding, network latency

9. **Disaggregation Postprocessing** (`disagg_postprocessing`)
   - **Time Period**: `gen_server_first_token_time` → `disagg_server_first_token_time`
   - **Description**: Routing overhead from the generation server back through the disagg server
   - **Includes**: Response forwarding, aggregation
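To make the segment definitions concrete, here is a small worked example in plain Python, using the timestamps from the aggregated-format record shown under Input Format below (all timestamps are in seconds):

```python
# Worked example: derive the context-stage (TTFT) segments, in milliseconds,
# from one aggregated-format record. Values match the sample under "Input Format".
t = {
    "server_arrival_time": 1.000,
    "arrival_time": 1.002,
    "first_scheduled_time": 1.005,
    "first_token_time": 1.025,
    "server_first_token_time": 1.027,
}

segments_ms = {
    "ctx_preprocessing": (t["arrival_time"] - t["server_arrival_time"]) * 1000,           # 2 ms
    "ctx_queue": (t["first_scheduled_time"] - t["arrival_time"]) * 1000,                  # 3 ms
    "ctx_processing": (t["first_token_time"] - t["first_scheduled_time"]) * 1000,         # 20 ms
    "ctx_postprocessing": (t["server_first_token_time"] - t["first_token_time"]) * 1000,  # 2 ms
}
print(segments_ms)  # the segments sum to the server-side TTFT (27 ms here)
```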
## Input Format

The tool expects a JSON file containing an array of request performance metrics (unit: seconds).

### Aggregated Format

```json
[
  {
    "request_id": 0,
    "perf_metrics": {
      "timing_metrics": {
        "server_arrival_time": 1.000,
        "arrival_time": 1.002,
        "first_scheduled_time": 1.005,
        "first_token_time": 1.025,
        "server_first_token_time": 1.027
      }
    }
  }
]
```

### Disaggregated Format

```json
[
  {
    "ctx_perf_metrics": {
      "request_id": 3,
      "perf_metrics": {
        "timing_metrics": {
          "server_arrival_time": 2.000,
          "arrival_time": 2.003,
          "first_scheduled_time": 2.008,
          "first_token_time": 2.035,
          "server_first_token_time": 2.038
        }
      }
    },
    "gen_perf_metrics": {
      "perf_metrics": {
        "timing_metrics": {
          "server_arrival_time": 2.050,
          "arrival_time": 2.052,
          "first_scheduled_time": 2.055,
          "first_token_time": 2.080,
          "server_first_token_time": 2.083
        }
      }
    },
    "disagg_server_arrival_time": 1.995,
    "disagg_server_first_token_time": 2.090
  }
]
```
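Both layouts are normalized by the bundled `RequestDataParser` into a flat dictionary of timestamps before durations are computed. A minimal sketch of reading a saved file this way (the file name is a placeholder):

```python
import json

from tensorrt_llm.serve.scripts.time_breakdown import RequestDataParser

parser = RequestDataParser()
with open("perf_metrics.json") as f:
    records = json.load(f)

# Each record (aggregated or disaggregated) becomes a flat dict with keys such
# as 'ctx_server_arrival_time', 'gen_arrival_time', and 'request_index'.
flat = [parser.parse_request(record, i) for i, record in enumerate(records)]
print(flat[0]["ctx_server_arrival_time"])
```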
## Usage

### Integration with Benchmark Serving

Step 1:
Set

```yaml
return_perf_metrics: True
perf_metrics_max_requests: <INTEGER>
```

in `extra-llm-api-config.yaml`. If you are running disaggregated serving, add these options to the configs of all servers (disagg, context, and generation).

Step 2:
Add `--save-request-time-breakdown` when running `benchmark_serving.py`:

```bash
python -m tensorrt_llm.serve.scripts.benchmark_serving \
    --model ${model_name} \
    --dataset-name random \
    --ignore-eos \
    --num-prompts 1000 \
    --random-input-len 1024 \
    --random-output-len 2048 \
    --random-ids \
    --max-concurrency 64 \
    --save-result \
    --result-dir <RESULT_DIR> \
    --percentile-metrics "ttft,tpot,itl,e2e" \
    --save-request-time-breakdown
```

You will be able to find the interactive time diagram in `<RESULT_DIR>`.

### As a CLI Tool

Step 1:
Download `perf_metrics.json` from the `/perf_metrics` endpoint of the trtllm server (for disaggregated serving, you only need to query the disagg server). Make sure the servers have `perf_metrics_max_requests` and `return_perf_metrics` configured.

```bash
curl -o perf_metrics.json <HOST>:<PORT>/perf_metrics
```

Step 2:
Process `perf_metrics.json` with `time_breakdown.py`:

```bash
# Basic usage - analyze and create the time diagram
python time_breakdown.py perf_metrics.json

# Specify a custom output file
python time_breakdown.py perf_metrics.json -o my_time_diagram.html

# Show statistics only (no diagram)
python time_breakdown.py perf_metrics.json --stats-only

# Create the diagram and show statistics
python time_breakdown.py perf_metrics.json --show-stats
```
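### As a Library

The same analysis can be driven from Python. A minimal sketch of the parse-and-plot flow used by the benchmark integration, plus the optional statistics printout (file names are placeholders):

```python
from tensorrt_llm.serve.scripts.time_breakdown import RequestTimeBreakdown

# Parse the saved metrics, render the interactive HTML diagram, and print
# per-segment statistics.
analyzer = RequestTimeBreakdown()
timing_data = analyzer.parse_json_file("perf_metrics.json")
if timing_data:
    analyzer.create_timing_diagram(timing_data, "time_diagram.html")
    analyzer.show_statistics(timing_data)
```

A custom `TimingMetricsConfig` (see `TimingMetricsConfig.add_metric`) can be passed as `RequestTimeBreakdown(config=...)` to chart additional segments.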
tensorrt_llm/serve/scripts/time_breakdown/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""
Time Breakdown Analysis Tool

This module provides tools for analyzing and visualizing request time breakdown
from TensorRT-LLM server performance metrics.
"""

from .time_breakdown import (RequestDataParser, RequestTimeBreakdown,
                             TimingMetric, TimingMetricsConfig, main)

__all__ = [
    'TimingMetric',
    'TimingMetricsConfig',
    'RequestDataParser',
    'RequestTimeBreakdown',
    'main',
]
tensorrt_llm/serve/scripts/time_breakdown/__main__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""
Entry point for running time_breakdown as a module.

Usage:
    python -m tensorrt_llm.serve.scripts.time_breakdown perf_metrics.json [options]
"""

from .time_breakdown import main

if __name__ == '__main__':
    main()
Binary file not shown (new image, 127 KiB).
tensorrt_llm/serve/scripts/time_breakdown/time_breakdown.py (new file, 550 lines)

@@ -0,0 +1,550 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Time Breakdown Analysis Tool
|
||||
|
||||
This module provides tools for analyzing and visualizing request time breakdown
|
||||
from TensorRT-LLM server performance metrics. It can be used both as a library
|
||||
and as a standalone CLI tool.
|
||||
|
||||
Usage as CLI:
|
||||
python time_breakdown.py <json_file> [options]
|
||||
|
||||
Usage as library:
|
||||
from time_breakdown import RequestTimeBreakdown
|
||||
analyzer = RequestTimeBreakdown()
|
||||
timing_data = analyzer.parse_json_file("perf_metrics.json")
|
||||
analyzer.create_timing_diagram(timing_data, "output.html")
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
import plotly.offline as pyo
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimingMetric:
|
||||
"""Configuration for a timing metric segment."""
|
||||
name: str
|
||||
display_name: str
|
||||
color: str
|
||||
description: str
|
||||
start_field: str
|
||||
end_field: str
|
||||
server_type: Optional[
|
||||
str] = None # 'ctx', 'gen', 'disagg', or None for direct calculation
|
||||
|
||||
def calculate_duration(self, timing_data: Dict[str, float]) -> float:
|
||||
"""Calculate the duration for this metric from timing data."""
|
||||
start_time = timing_data.get(self.start_field, float('nan'))
|
||||
end_time = timing_data.get(self.end_field, float('nan'))
|
||||
|
||||
# If either timestamp is NaN (not available), treat the segment as missing and return a zero duration
|
||||
if math.isnan(start_time) or math.isnan(end_time):
|
||||
print(f"Warning: {self.name} has NaN start or end time")
|
||||
return 0
|
||||
|
||||
if start_time > end_time:
|
||||
print(f"Warning: {self.name} has start time after end time")
|
||||
return 0
|
||||
|
||||
return end_time - start_time
|
||||
|
||||
|
||||
class TimingMetricsConfig:
|
||||
"""Configuration class that defines all available timing metrics."""
|
||||
|
||||
def __init__(self):
|
||||
self.metrics = [
|
||||
TimingMetric(
|
||||
name='disagg_preprocessing',
|
||||
display_name='Disagg Preprocessing',
|
||||
color='#B8B8B8', # Light gray
|
||||
description=
|
||||
'Time from when the disagg server receives the request until a context server receives it',
|
||||
start_field='disagg_server_arrival_time',
|
||||
end_field='ctx_server_arrival_time',
|
||||
server_type='disagg'),
|
||||
TimingMetric(
|
||||
name='ctx_preprocessing',
|
||||
display_name='Context Preprocessing',
|
||||
color='#90EE90', # Light green
|
||||
description=
|
||||
'Time from when a context server receives the request until an LLM worker queues it',
|
||||
start_field='ctx_server_arrival_time',
|
||||
end_field='ctx_arrival_time',
|
||||
server_type='ctx'),
|
||||
TimingMetric(
|
||||
name='ctx_queue',
|
||||
display_name='Context Queue',
|
||||
color='#FFB347', # Light orange
|
||||
description=
|
||||
'Time from when the request is queued until it is first scheduled',
|
||||
start_field='ctx_arrival_time',
|
||||
end_field='ctx_first_scheduled_time',
|
||||
server_type='ctx'),
|
||||
TimingMetric(
|
||||
name='ctx_processing',
|
||||
display_name='Context Processing',
|
||||
color='#6495ED', # Cornflower blue
|
||||
description=
|
||||
'Time from first schedule until the first token is generated on an LLM worker',
|
||||
start_field='ctx_first_scheduled_time',
|
||||
end_field='ctx_first_token_time',
|
||||
server_type='ctx'),
|
||||
TimingMetric(
|
||||
name='ctx_postprocessing',
|
||||
display_name='Context Postprocessing',
|
||||
color='#DDA0DD', # Plum
|
||||
description=
|
||||
'Time from when the first token is generated on an LLM worker until the first token response is sent by the context server',
|
||||
start_field='ctx_first_token_time',
|
||||
end_field='ctx_server_first_token_time',
|
||||
server_type='ctx'),
|
||||
TimingMetric(
|
||||
name='gen_preprocessing',
|
||||
display_name='Generation Preprocessing',
|
||||
color='#FFE66D', # Bright yellow
|
||||
description=
|
||||
'Time from when a generation server receives the request until an LLM worker receives it',
|
||||
start_field='gen_server_arrival_time',
|
||||
end_field='gen_arrival_time',
|
||||
server_type='gen'),
|
||||
TimingMetric(
|
||||
name='gen_queue',
|
||||
display_name='Generation Queue',
|
||||
color='#FF6B6B', # Coral red
|
||||
description=
|
||||
'Time from when the request is queued until it is first scheduled',
|
||||
start_field='gen_arrival_time',
|
||||
end_field='gen_first_scheduled_time',
|
||||
server_type='gen'),
|
||||
TimingMetric(
|
||||
name='gen_postprocessing',
|
||||
display_name='Generation Postprocessing',
|
||||
color='#95E1D3', # Mint/teal
|
||||
description=
|
||||
'Time from first schedule until the first token response is sent by the generation server',
|
||||
start_field='gen_first_scheduled_time',
|
||||
end_field='gen_server_first_token_time',
|
||||
server_type='gen'),
|
||||
TimingMetric(
|
||||
name='disagg_postprocessing',
|
||||
display_name='Disagg Postprocessing',
|
||||
color='#A9A9A9', # Dark gray
|
||||
description=
|
||||
'Time from when the first token response is sent by the generation server until it is sent by the disagg server',
|
||||
start_field='gen_server_first_token_time',
|
||||
end_field='disagg_server_first_token_time',
|
||||
server_type='disagg')
|
||||
]
|
||||
|
||||
def get_metric_by_name(self, name: str) -> Optional[TimingMetric]:
|
||||
"""Get a metric by its name."""
|
||||
return next((m for m in self.metrics if m.name == name), None)
|
||||
|
||||
def get_metrics_by_server(self, server_type: str) -> List[TimingMetric]:
|
||||
"""Get all metrics for a specific server type."""
|
||||
return [m for m in self.metrics if m.server_type == server_type]
|
||||
|
||||
def add_metric(self, metric: TimingMetric):
|
||||
"""Add a new timing metric."""
|
||||
self.metrics.append(metric)
|
||||
|
||||
def remove_metric(self, name: str):
|
||||
"""Remove a timing metric by name."""
|
||||
self.metrics = [m for m in self.metrics if m.name != name]
|
||||
|
||||
|
||||
class RequestDataParser:
|
||||
"""Parser for disaggregated format with ctx_perf_metrics and gen_perf_metrics."""
|
||||
|
||||
def parse_request(self, request_data: Dict,
|
||||
request_index: int) -> Dict[str, Any]:
|
||||
is_disaggregated = 'ctx_perf_metrics' in request_data and 'gen_perf_metrics' in request_data
|
||||
|
||||
ctx_metrics = {}
|
||||
gen_metrics = {}
|
||||
if is_disaggregated:
|
||||
ctx_metrics = request_data.get('ctx_perf_metrics', {}).get(
|
||||
'perf_metrics', {}).get('timing_metrics', {})
|
||||
gen_metrics = request_data.get('gen_perf_metrics', {}).get(
|
||||
'perf_metrics', {}).get('timing_metrics', {})
|
||||
else:
|
||||
ctx_metrics = request_data.get('perf_metrics',
|
||||
{}).get('timing_metrics', {})
|
||||
|
||||
ctx_arrival_time = ctx_metrics.get('arrival_time', 0)
|
||||
ctx_first_scheduled_time = ctx_metrics.get('first_scheduled_time', 0)
|
||||
ctx_first_token_time = ctx_metrics.get('first_token_time', 0)
|
||||
ctx_server_arrival_time = ctx_metrics.get('server_arrival_time', 0)
|
||||
ctx_server_first_token_time = ctx_metrics.get('server_first_token_time',
|
||||
0)
|
||||
|
||||
gen_server_first_token_time = gen_metrics.get('server_first_token_time',
|
||||
0)
|
||||
gen_server_arrival_time = gen_metrics.get('server_arrival_time', 0)
|
||||
gen_arrival_time = gen_metrics.get('arrival_time', 0)
|
||||
gen_first_token_time = gen_metrics.get('first_token_time', 0)
|
||||
gen_first_scheduled_time = gen_metrics.get('first_scheduled_time', 0)
|
||||
|
||||
disagg_server_arrival_time = 0
|
||||
disagg_server_first_token_time = 0
|
||||
if is_disaggregated:
|
||||
disagg_server_arrival_time = request_data.get(
|
||||
'disagg_server_arrival_time', 0)
|
||||
disagg_server_first_token_time = request_data.get(
|
||||
'disagg_server_first_token_time', 0)
|
||||
|
||||
# Get request ID
|
||||
if is_disaggregated:
|
||||
request_id = request_data.get('ctx_perf_metrics',
|
||||
{}).get('request_id', request_index)
|
||||
else:
|
||||
request_id = request_data.get('request_id', request_index)
|
||||
|
||||
return {
|
||||
'request_index': request_id,
|
||||
'ctx_server_arrival_time': ctx_server_arrival_time,
|
||||
'ctx_arrival_time': ctx_arrival_time,
|
||||
'ctx_first_scheduled_time': ctx_first_scheduled_time,
|
||||
'ctx_first_token_time': ctx_first_token_time,
|
||||
'ctx_server_first_token_time': ctx_server_first_token_time,
|
||||
'gen_server_arrival_time': gen_server_arrival_time,
|
||||
'gen_arrival_time': gen_arrival_time,
|
||||
'gen_first_scheduled_time': gen_first_scheduled_time,
|
||||
'gen_first_token_time': gen_first_token_time,
|
||||
'gen_server_first_token_time': gen_server_first_token_time,
|
||||
'disagg_server_arrival_time': disagg_server_arrival_time,
|
||||
'disagg_server_first_token_time': disagg_server_first_token_time,
|
||||
}
|
||||
|
||||
|
||||
class RequestTimeBreakdown:
|
||||
"""Main class for analyzing request time breakdown."""
|
||||
|
||||
def __init__(self, config: Optional[TimingMetricsConfig] = None):
|
||||
self.config = config or TimingMetricsConfig()
|
||||
self.parser = RequestDataParser()
|
||||
|
||||
def parse_json_file(self, json_file_path: str) -> List[Dict]:
|
||||
"""Parse JSON performance metrics file and extract timing information."""
|
||||
try:
|
||||
with open(json_file_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: File '{json_file_path}' not found.")
|
||||
sys.exit(1)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error parsing JSON file '{json_file_path}': {e}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"Error reading file '{json_file_path}': {e}")
|
||||
sys.exit(1)
|
||||
|
||||
timing_data = []
|
||||
|
||||
for i, request in enumerate(data):
|
||||
parsed_data = self.parser.parse_request(request, i)
|
||||
|
||||
# Calculate durations for each metric
|
||||
for metric in self.config.metrics:
|
||||
duration = metric.calculate_duration(parsed_data)
|
||||
parsed_data[f'{metric.name}_time'] = duration
|
||||
|
||||
timing_data.append(parsed_data)
|
||||
|
||||
if timing_data:
|
||||
has_gen_metrics = any(entry['gen_server_arrival_time'] > 0
|
||||
for entry in timing_data)
|
||||
format_type = "disaggregated " if has_gen_metrics else "aggregated"
|
||||
print(
|
||||
f"Parsed timing data for {len(timing_data)} requests from {json_file_path} ({format_type} format)"
|
||||
)
|
||||
else:
|
||||
print(f"Parsed timing data for 0 requests from {json_file_path}")
|
||||
|
||||
return timing_data
|
||||
|
||||
def create_timing_diagram(self,
|
||||
timing_data: List[Dict],
|
||||
output_file: Optional[str] = None):
|
||||
"""Create an interactive HTML stacked bar chart showing time breakdown."""
|
||||
if not timing_data:
|
||||
print("No timing data to visualize.")
|
||||
return
|
||||
|
||||
# Extract data for plotting
|
||||
request_indices = [data['request_index'] for data in timing_data]
|
||||
|
||||
# Create the interactive plot
|
||||
fig = go.Figure()
|
||||
|
||||
# Add traces for each metric
|
||||
for metric in self.config.metrics:
|
||||
times_ms = [
|
||||
data.get(f'{metric.name}_time', 0) * 1000
|
||||
for data in timing_data
|
||||
]
|
||||
|
||||
# Only add trace if there's some non-zero data
|
||||
if any(t > 0 for t in times_ms):
|
||||
fig.add_trace(
|
||||
go.Bar(
|
||||
x=request_indices,
|
||||
y=times_ms,
|
||||
name=metric.display_name,
|
||||
marker_color=metric.color,
|
||||
hovertemplate=
|
||||
f'<b>Request %{{x}}</b><br>{metric.display_name}: %{{y:.2f}} ms<extra></extra>'
|
||||
))
|
||||
|
||||
# Update layout
|
||||
fig.update_layout(
|
||||
barmode='stack',
|
||||
title={
|
||||
'text':
|
||||
'Request Time Breakdown<br><sub>Time Spent in Each Segment (Interactive)</sub>',
|
||||
'x': 0.5,
|
||||
'xanchor': 'center',
|
||||
'font': {
|
||||
'size': 16
|
||||
}
|
||||
},
|
||||
xaxis_title='Request Index',
|
||||
yaxis_title='Time (milliseconds)',
|
||||
hovermode='x unified',
|
||||
legend=dict(orientation="v",
|
||||
yanchor="top",
|
||||
y=1,
|
||||
xanchor="left",
|
||||
x=1.02),
|
||||
width=1200,
|
||||
height=700,
|
||||
margin=dict(r=200))
|
||||
|
||||
# Calculate and add statistics
|
||||
self._add_statistics_annotation(fig, timing_data)
|
||||
|
||||
# Set default output filename if not provided
|
||||
if not output_file:
|
||||
output_file = 'time_breakdown.html'
|
||||
elif not output_file.endswith('.html'):
|
||||
output_file += '.html'
|
||||
|
||||
# Generate the plotly div
|
||||
plot_div = pyo.plot(fig,
|
||||
output_type='div',
|
||||
include_plotlyjs='cdn',
|
||||
auto_open=False)
|
||||
|
||||
# Generate descriptions HTML
|
||||
descriptions_html = self._generate_descriptions_html(timing_data)
|
||||
|
||||
# Combine into full HTML
|
||||
full_html = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>Request Timing Breakdown</title>
|
||||
<style>
|
||||
body {{
|
||||
font-family: Arial, sans-serif;
|
||||
margin: 20px;
|
||||
}}
|
||||
.descriptions {{
|
||||
margin-top: 30px;
|
||||
padding: 20px;
|
||||
background-color: #f5f5f5;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 5px;
|
||||
max-width: 1200px;
|
||||
}}
|
||||
.descriptions h2 {{
|
||||
margin-top: 0;
|
||||
color: #333;
|
||||
}}
|
||||
.metric-desc {{
|
||||
margin-bottom: 15px;
|
||||
line-height: 1.6;
|
||||
}}
|
||||
.metric-name {{
|
||||
font-weight: bold;
|
||||
color: #2c3e50;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
{plot_div}
|
||||
{descriptions_html}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
# Write to file
|
||||
with open(output_file, 'w') as f:
|
||||
f.write(full_html)
|
||||
|
||||
print(f"Interactive time breakdown diagram saved to: {output_file}")
|
||||
print(f"Open the file in your web browser to interact with the chart!")
|
||||
|
||||
def _add_statistics_annotation(self, fig, timing_data: List[Dict]):
|
||||
"""Add statistics annotation to the plot."""
|
||||
# Calculate median times for each metric
|
||||
stats_lines = ['<b>Median Times (ms):</b>']
|
||||
total_times = []
|
||||
|
||||
for metric in self.config.metrics:
|
||||
times = [
|
||||
data.get(f'{metric.name}_time', 0) * 1000
|
||||
for data in timing_data
|
||||
]
|
||||
if any(t > 0 for t in times):
|
||||
median_time = np.median(times)
|
||||
stats_lines.append(f'{metric.display_name}: {median_time:.2f}')
|
||||
|
||||
# Calculate total time per request
|
||||
for data in timing_data:
|
||||
total = sum(
|
||||
data.get(f'{metric.name}_time', 0) * 1000
|
||||
for metric in self.config.metrics)
|
||||
total_times.append(total)
|
||||
|
||||
if total_times:
|
||||
median_total = np.median(total_times)
|
||||
stats_lines.append(f'<b>Total per Request: {median_total:.2f}</b>')
|
||||
|
||||
stats_lines.append(f'<b>Requests: {len(timing_data)}</b>')
|
||||
|
||||
stats_text = '<br>'.join(stats_lines)
|
||||
|
||||
fig.add_annotation(x=0.98,
|
||||
y=0.98,
|
||||
xref='paper',
|
||||
yref='paper',
|
||||
text=stats_text,
|
||||
showarrow=False,
|
||||
align='right',
|
||||
bgcolor='rgba(255, 255, 255, 0.8)',
|
||||
bordercolor='black',
|
||||
borderwidth=1,
|
||||
font=dict(size=10))
|
||||
|
||||
def _generate_descriptions_html(self, timing_data: List[Dict]) -> str:
|
||||
"""Generate HTML for metric descriptions section."""
|
||||
desc_items = []
|
||||
|
||||
for metric in self.config.metrics:
|
||||
times = [
|
||||
data.get(f'{metric.name}_time', 0) * 1000
|
||||
for data in timing_data
|
||||
]
|
||||
# Only include metrics that have non-zero data
|
||||
if any(t > 0 for t in times):
|
||||
desc_items.append(
|
||||
f'<div class="metric-desc">'
|
||||
f'<span class="metric-name">{metric.display_name}:</span> '
|
||||
f'{metric.description}'
|
||||
f'</div>')
|
||||
|
||||
if not desc_items:
|
||||
return ''
|
||||
|
||||
descriptions_html = f"""
|
||||
<div class="descriptions">
|
||||
<h2>Metric Descriptions</h2>
|
||||
{''.join(desc_items)}
|
||||
Reference: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/serve/scripts/time_breakdown/README.md
|
||||
</div>
|
||||
"""
|
||||
return descriptions_html
|
||||
|
||||
def show_statistics(self, timing_data: List[Dict]):
|
||||
"""Show detailed statistics about the timing data."""
|
||||
if not timing_data:
|
||||
print("No timing data to analyze.")
|
||||
return
|
||||
|
||||
print("\n=== Timing Statistics ===")
|
||||
print(f"Total requests: {len(timing_data)}")
|
||||
|
||||
for metric in self.config.metrics:
|
||||
times = [data.get(f'{metric.name}_time', 0) for data in timing_data]
|
||||
if any(t > 0 for t in times):
|
||||
print(f"\n{metric.display_name} Times (seconds):")
|
||||
print(f" Range: {min(times):.3f} to {max(times):.3f}")
|
||||
print(f" Median: {np.median(times):.3f}")
|
||||
print(f" Description: {metric.description}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main CLI entry point."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Analyze and visualize TensorRT-LLM server time breakdown',
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Analyze performance metrics and create timing diagram
|
||||
python time_breakdown.py perf_metrics.json
|
||||
|
||||
# Specify custom output file
|
||||
python time_breakdown.py perf_metrics.json -o my_timing.html
|
||||
|
||||
# Show statistics only (no diagram)
|
||||
python time_breakdown.py perf_metrics.json --stats-only
|
||||
|
||||
# Create diagram and show statistics
|
||||
python time_breakdown.py perf_metrics.json --show-stats
|
||||
""")
|
||||
|
||||
parser.add_argument('json_file',
|
||||
type=str,
|
||||
help='Path to the JSON performance metrics file')
|
||||
|
||||
parser.add_argument(
|
||||
'-o',
|
||||
'--output',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Output HTML file path (default: time_breakdown.html)')
|
||||
|
||||
parser.add_argument('--stats-only',
|
||||
action='store_true',
|
||||
help='Show statistics only without creating diagram')
|
||||
|
||||
parser.add_argument('--show-stats',
|
||||
action='store_true',
|
||||
help='Show statistics in addition to creating diagram')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create analyzer
|
||||
analyzer = RequestTimeBreakdown()
|
||||
|
||||
# Parse the JSON file
|
||||
print(f"Parsing timing data from: {args.json_file}")
|
||||
timing_data = analyzer.parse_json_file(args.json_file)
|
||||
|
||||
if not timing_data:
|
||||
print("No timing data found in the file.")
|
||||
sys.exit(1)
|
||||
|
||||
# Show statistics if requested
|
||||
if args.stats_only or args.show_stats:
|
||||
analyzer.show_statistics(timing_data)
|
||||
|
||||
# Create diagram unless stats-only mode
|
||||
if not args.stats_only:
|
||||
analyzer.create_timing_diagram(timing_data, args.output)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -20,6 +20,7 @@ l0_a10:
|
||||
# NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
|
||||
# test list either).
|
||||
- unittest/_torch/models/checkpoints/hf/test_weight_loader.py
|
||||
- unittest/others/test_time_breakdown.py
|
||||
- disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
|
||||
- disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun_trt_backend[TinyLlama-1.1B-Chat-v1.0]
|
||||
- disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]
|
||||
|
||||
tests/unittest/others/test_time_breakdown.py (new file, 504 lines)

@@ -0,0 +1,504 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Unit tests for time_breakdown module
|
||||
|
||||
Run tests with:
|
||||
python -m pytest tests/unittest/others/test_time_breakdown.py -v
|
||||
or
|
||||
python -m unittest tests.unittest.others.test_time_breakdown
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from tensorrt_llm.serve.scripts.time_breakdown import (RequestDataParser,
|
||||
RequestTimeBreakdown,
|
||||
TimingMetric,
|
||||
TimingMetricsConfig)
|
||||
|
||||
|
||||
class TestTimingMetric(unittest.TestCase):
|
||||
"""Test TimingMetric class."""
|
||||
|
||||
def test_timing_metric_creation(self):
|
||||
"""Test basic TimingMetric creation."""
|
||||
metric = TimingMetric(name='test_metric',
|
||||
display_name='Test Metric',
|
||||
color='blue',
|
||||
description='Test description',
|
||||
start_field='start_time',
|
||||
end_field='end_time',
|
||||
server_type='ctx')
|
||||
|
||||
self.assertEqual(metric.name, 'test_metric')
|
||||
self.assertEqual(metric.display_name, 'Test Metric')
|
||||
self.assertEqual(metric.color, 'blue')
|
||||
self.assertEqual(metric.description, 'Test description')
|
||||
self.assertEqual(metric.start_field, 'start_time')
|
||||
self.assertEqual(metric.end_field, 'end_time')
|
||||
self.assertEqual(metric.server_type, 'ctx')
|
||||
|
||||
def test_calculate_duration_valid(self):
|
||||
"""Test duration calculation with valid timestamps."""
|
||||
metric = TimingMetric(name='test',
|
||||
display_name='Test',
|
||||
color='blue',
|
||||
description='Test',
|
||||
start_field='start_time',
|
||||
end_field='end_time')
|
||||
|
||||
timing_data = {'start_time': 1.0, 'end_time': 3.5}
|
||||
|
||||
duration = metric.calculate_duration(timing_data)
|
||||
self.assertEqual(duration, 2.5)
|
||||
|
||||
def test_calculate_duration_missing_start(self):
|
||||
"""Test duration calculation with missing start time."""
|
||||
metric = TimingMetric(name='test',
|
||||
display_name='Test',
|
||||
color='blue',
|
||||
description='Test',
|
||||
start_field='start_time',
|
||||
end_field='end_time')
|
||||
|
||||
timing_data = {'end_time': 3.5}
|
||||
|
||||
duration = metric.calculate_duration(timing_data)
|
||||
self.assertEqual(duration, 0.0)
|
||||
|
||||
def test_calculate_duration_missing_end(self):
|
||||
"""Test duration calculation with missing end time."""
|
||||
metric = TimingMetric(name='test',
|
||||
display_name='Test',
|
||||
color='blue',
|
||||
description='Test',
|
||||
start_field='start_time',
|
||||
end_field='end_time')
|
||||
|
||||
timing_data = {'start_time': 1.0, 'end_time': 0}
|
||||
|
||||
duration = metric.calculate_duration(timing_data)
|
||||
self.assertEqual(duration, 0.0)
|
||||
|
||||
def test_calculate_duration_negative(self):
|
||||
"""Test duration calculation doesn't produce negative values."""
|
||||
metric = TimingMetric(name='test',
|
||||
display_name='Test',
|
||||
color='blue',
|
||||
description='Test',
|
||||
start_field='start_time',
|
||||
end_field='end_time')
|
||||
|
||||
timing_data = {'start_time': 5.0, 'end_time': 3.5}
|
||||
|
||||
duration = metric.calculate_duration(timing_data)
|
||||
self.assertEqual(duration, 0.0)
|
||||
|
||||
|
||||
class TestTimingMetricsConfig(unittest.TestCase):
|
||||
"""Test TimingMetricsConfig class."""
|
||||
|
||||
def test_default_metrics_loaded(self):
|
||||
"""Test that default metrics are loaded."""
|
||||
config = TimingMetricsConfig()
|
||||
|
||||
# Should have multiple default metrics
|
||||
self.assertGreater(len(config.metrics), 0)
|
||||
|
||||
# Check for expected metric names
|
||||
metric_names = [m.name for m in config.metrics]
|
||||
self.assertIn('ctx_preprocessing', metric_names)
|
||||
self.assertIn('ctx_processing', metric_names)
|
||||
|
||||
def test_get_metric_by_name(self):
|
||||
"""Test retrieving a metric by name."""
|
||||
config = TimingMetricsConfig()
|
||||
|
||||
metric = config.get_metric_by_name('ctx_preprocessing')
|
||||
self.assertIsNotNone(metric)
|
||||
self.assertEqual(metric.name, 'ctx_preprocessing')
|
||||
|
||||
# Test non-existent metric
|
||||
metric = config.get_metric_by_name('non_existent')
|
||||
self.assertIsNone(metric)
|
||||
|
||||
def test_get_metrics_by_server(self):
|
||||
"""Test retrieving metrics by server type."""
|
||||
config = TimingMetricsConfig()
|
||||
|
||||
ctx_metrics = config.get_metrics_by_server('ctx')
|
||||
self.assertGreater(len(ctx_metrics), 0)
|
||||
|
||||
# All returned metrics should be for 'ctx' server
|
||||
for metric in ctx_metrics:
|
||||
self.assertEqual(metric.server_type, 'ctx')
|
||||
|
||||
def test_add_metric(self):
|
||||
"""Test adding a new metric."""
|
||||
config = TimingMetricsConfig()
|
||||
initial_count = len(config.metrics)
|
||||
|
||||
new_metric = TimingMetric(name='custom_metric',
|
||||
display_name='Custom Metric',
|
||||
color='red',
|
||||
description='Custom test metric',
|
||||
start_field='start',
|
||||
end_field='end')
|
||||
|
||||
config.add_metric(new_metric)
|
||||
self.assertEqual(len(config.metrics), initial_count + 1)
|
||||
self.assertIsNotNone(config.get_metric_by_name('custom_metric'))
|
||||
|
||||
def test_remove_metric(self):
|
||||
"""Test removing a metric."""
|
||||
config = TimingMetricsConfig()
|
||||
initial_count = len(config.metrics)
|
||||
|
||||
# Add a test metric first
|
||||
test_metric = TimingMetric(name='test_to_remove',
|
||||
display_name='Test',
|
||||
color='blue',
|
||||
description='Test',
|
||||
start_field='start',
|
||||
end_field='end')
|
||||
config.add_metric(test_metric)
|
||||
|
||||
# Remove it
|
||||
config.remove_metric('test_to_remove')
|
||||
self.assertEqual(len(config.metrics), initial_count)
|
||||
self.assertIsNone(config.get_metric_by_name('test_to_remove'))
|
||||
|
||||
|
||||
class TestRequestDataParser(unittest.TestCase):
|
||||
"""Test RequestDataParser class."""
|
||||
|
||||
def test_parse_aggregated_format(self):
|
||||
"""Test parsing aggregated (non-disaggregated) format."""
|
||||
parser = RequestDataParser()
|
||||
|
||||
request_data = {
|
||||
'request_id': 'req123',
|
||||
'perf_metrics': {
|
||||
'timing_metrics': {
|
||||
'server_arrival_time': 1.0,
|
||||
'arrival_time': 1.1,
|
||||
'first_scheduled_time': 1.2,
|
||||
'first_token_time': 1.5,
|
||||
'server_first_token_time': 1.6
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
parsed = parser.parse_request(request_data, 0)
|
||||
|
||||
self.assertEqual(parsed['request_index'], 'req123')
|
||||
self.assertEqual(parsed['ctx_server_arrival_time'], 1.0)
|
||||
self.assertEqual(parsed['ctx_arrival_time'], 1.1)
|
||||
self.assertEqual(parsed['ctx_first_scheduled_time'], 1.2)
|
||||
self.assertEqual(parsed['ctx_first_token_time'], 1.5)
|
||||
self.assertEqual(parsed['ctx_server_first_token_time'], 1.6)
|
||||
|
||||
# Gen metrics should be 0 in aggregated format
|
||||
self.assertEqual(parsed['gen_server_arrival_time'], 0)
|
||||
self.assertEqual(parsed['disagg_server_arrival_time'], 0)
|
||||
|
||||
def test_parse_disaggregated_format(self):
|
||||
"""Test parsing disaggregated format."""
|
||||
parser = RequestDataParser()
|
||||
|
||||
request_data = {
|
||||
'ctx_perf_metrics': {
|
||||
'request_id': 'req456',
|
||||
'perf_metrics': {
|
||||
'timing_metrics': {
|
||||
'server_arrival_time': 1.0,
|
||||
'arrival_time': 1.1,
|
||||
'first_scheduled_time': 1.2,
|
||||
'first_token_time': 1.5,
|
||||
'server_first_token_time': 1.6
|
||||
}
|
||||
}
|
||||
},
|
||||
'gen_perf_metrics': {
|
||||
'perf_metrics': {
|
||||
'timing_metrics': {
|
||||
'server_arrival_time': 2.0,
|
||||
'arrival_time': 2.1,
|
||||
'first_scheduled_time': 2.2,
|
||||
'first_token_time': 2.5,
|
||||
'server_first_token_time': 2.6
|
||||
}
|
||||
}
|
||||
},
|
||||
'disagg_server_arrival_time': 0.5,
|
||||
'disagg_server_first_token_time': 3.0
|
||||
}
|
||||
|
||||
parsed = parser.parse_request(request_data, 0)
|
||||
|
||||
self.assertEqual(parsed['request_index'], 'req456')
|
||||
|
||||
# Context metrics
|
||||
self.assertEqual(parsed['ctx_server_arrival_time'], 1.0)
|
||||
self.assertEqual(parsed['ctx_arrival_time'], 1.1)
|
||||
|
||||
# Generation metrics
|
||||
self.assertEqual(parsed['gen_server_arrival_time'], 2.0)
|
||||
self.assertEqual(parsed['gen_arrival_time'], 2.1)
|
||||
|
||||
# Disaggregation metrics
|
||||
self.assertEqual(parsed['disagg_server_arrival_time'], 0.5)
|
||||
self.assertEqual(parsed['disagg_server_first_token_time'], 3.0)
|
||||
|
||||
def test_parse_missing_fields(self):
|
||||
"""Test parsing with missing fields (should default to 0)."""
|
||||
parser = RequestDataParser()
|
||||
|
||||
request_data = {
|
||||
'request_id': 'req789',
|
||||
'perf_metrics': {
|
||||
'timing_metrics': {}
|
||||
}
|
||||
}
|
||||
|
||||
parsed = parser.parse_request(request_data, 0)
|
||||
|
||||
# All timing fields should default to 0
|
||||
self.assertEqual(parsed['ctx_server_arrival_time'], 0)
|
||||
self.assertEqual(parsed['ctx_arrival_time'], 0)
|
||||
self.assertEqual(parsed['gen_server_arrival_time'], 0)
|
||||
|
||||
def test_parse_uses_index_as_fallback(self):
|
||||
"""Test that index is used when request_id is missing."""
|
||||
parser = RequestDataParser()
|
||||
|
||||
request_data = {'perf_metrics': {'timing_metrics': {}}}
|
||||
|
||||
parsed = parser.parse_request(request_data, 42)
|
||||
|
||||
self.assertEqual(parsed['request_index'], 42)
|
||||
|
||||
|
||||
class TestRequestTimeBreakdown(unittest.TestCase):
|
||||
"""Test RequestTimeBreakdown class."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.analyzer = RequestTimeBreakdown()
|
||||
|
||||
# Create a temporary JSON file for testing
|
||||
self.test_data = [{
|
||||
'request_id': 0,
|
||||
'perf_metrics': {
|
||||
'timing_metrics': {
|
||||
'server_arrival_time': 1.0,
|
||||
'arrival_time': 1.1,
|
||||
'first_scheduled_time': 1.2,
|
||||
'first_token_time': 1.5,
|
||||
'server_first_token_time': 1.6
|
||||
}
|
||||
}
|
||||
}, {
|
||||
'request_id': 1,
|
||||
'perf_metrics': {
|
||||
'timing_metrics': {
|
||||
'server_arrival_time': 2.0,
|
||||
'arrival_time': 2.1,
|
||||
'first_scheduled_time': 2.3,
|
||||
'first_token_time': 2.7,
|
||||
'server_first_token_time': 2.8
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
def test_parse_json_file(self):
|
||||
"""Test parsing a JSON file."""
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json',
|
||||
delete=False) as f:
|
||||
json.dump(self.test_data, f)
|
||||
temp_file = f.name
|
||||
|
||||
try:
|
||||
timing_data = self.analyzer.parse_json_file(temp_file)
|
||||
|
||||
self.assertEqual(len(timing_data), 2)
|
||||
|
||||
# Check first request
|
||||
self.assertEqual(timing_data[0]['request_index'], 0)
|
||||
self.assertEqual(timing_data[0]['ctx_server_arrival_time'], 1.0)
|
||||
|
||||
# Check that durations were calculated
|
||||
self.assertIn('ctx_preprocessing_time', timing_data[0])
|
||||
self.assertIn('ctx_queue_time', timing_data[0])
|
||||
|
||||
# Verify a specific duration calculation
|
||||
# ctx_preprocessing = ctx_arrival_time - ctx_server_arrival_time
|
||||
expected_preprocessing = 1.1 - 1.0
|
||||
self.assertAlmostEqual(timing_data[0]['ctx_preprocessing_time'],
|
||||
expected_preprocessing,
|
||||
places=5)
|
||||
finally:
|
||||
os.unlink(temp_file)
|
||||
|
||||
def test_parse_json_file_not_found(self):
|
||||
"""Test parsing a non-existent file."""
|
||||
with self.assertRaises(SystemExit):
|
||||
self.analyzer.parse_json_file('non_existent_file.json')
|
||||
|
||||
def test_parse_json_file_invalid_json(self):
|
||||
"""Test parsing an invalid JSON file."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json',
|
||||
delete=False) as f:
|
||||
f.write("{ invalid json")
|
||||
temp_file = f.name
|
||||
|
||||
try:
|
||||
with self.assertRaises(SystemExit):
|
||||
self.analyzer.parse_json_file(temp_file)
|
||||
finally:
|
||||
os.unlink(temp_file)
|
||||
|
||||
def test_create_timing_diagram(self):
|
||||
"""Test creating a timing diagram."""
|
||||
# Create sample timing data
|
||||
timing_data = [{
|
||||
'request_index': 0,
|
||||
'ctx_preprocessing_time': 0.1,
|
||||
'ctx_queue_time': 0.2,
|
||||
'ctx_processing_time': 0.3,
|
||||
'ctx_postprocessing_time': 0.05,
|
||||
'gen_preprocessing_time': 0,
|
||||
'gen_queue_time': 0,
|
||||
'gen_postprocessing_time': 0,
|
||||
'disagg_preprocessing_time': 0,
|
||||
'disagg_postprocessing_time': 0,
|
||||
}, {
|
||||
'request_index': 1,
|
||||
'ctx_preprocessing_time': 0.15,
|
||||
'ctx_queue_time': 0.25,
|
||||
'ctx_processing_time': 0.35,
|
||||
'ctx_postprocessing_time': 0.06,
|
||||
'gen_preprocessing_time': 0,
|
||||
'gen_queue_time': 0,
|
||||
'gen_postprocessing_time': 0,
|
||||
'disagg_preprocessing_time': 0,
|
||||
'disagg_postprocessing_time': 0,
|
||||
}]
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.html', delete=False) as f:
|
||||
temp_file = f.name
|
||||
|
||||
try:
|
||||
# Mock plotly to avoid actual file creation
|
||||
with patch(
|
||||
'tensorrt_llm.serve.scripts.time_breakdown.time_breakdown.pyo.plot'
|
||||
) as mock_plot:
|
||||
self.analyzer.create_timing_diagram(timing_data, temp_file)
|
||||
|
||||
# Verify that plot was called
|
||||
mock_plot.assert_called_once()
|
||||
finally:
|
||||
if os.path.exists(temp_file):
|
||||
os.unlink(temp_file)
|
||||
|
||||
def test_create_timing_diagram_empty_data(self):
|
||||
"""Test creating a diagram with empty data."""
|
||||
# Should handle gracefully without creating a file
|
||||
with patch('builtins.print') as mock_print:
|
||||
self.analyzer.create_timing_diagram([])
|
||||
mock_print.assert_called_with("No timing data to visualize.")
|
||||
|
||||
def test_show_statistics(self):
|
||||
"""Test showing statistics."""
|
||||
timing_data = [{
|
||||
'ctx_preprocessing_time': 0.1,
|
||||
'ctx_queue_time': 0.2,
|
||||
'ctx_processing_time': 0.3,
|
||||
}, {
|
||||
'ctx_preprocessing_time': 0.15,
|
||||
'ctx_queue_time': 0.25,
|
||||
'ctx_processing_time': 0.35,
|
||||
}]
|
||||
|
||||
# Capture printed output
|
||||
with patch('builtins.print') as mock_print:
|
||||
self.analyzer.show_statistics(timing_data)
|
||||
|
||||
# Should have printed something
|
||||
self.assertTrue(mock_print.called)
|
||||
|
||||
# Check for expected content in printed output
|
||||
printed_output = ' '.join(
|
||||
[str(call[0][0]) for call in mock_print.call_args_list])
|
||||
self.assertIn('Total requests', printed_output)
|
||||
|
||||
def test_show_statistics_empty_data(self):
|
||||
"""Test showing statistics with empty data."""
|
||||
with patch('builtins.print') as mock_print:
|
||||
self.analyzer.show_statistics([])
|
||||
mock_print.assert_called_with("No timing data to analyze.")
|
||||
|
||||
def test_custom_config(self):
|
||||
"""Test using a custom configuration."""
|
||||
custom_config = TimingMetricsConfig()
|
||||
custom_config.add_metric(
|
||||
TimingMetric(name='custom_metric',
|
||||
display_name='Custom',
|
||||
color='red',
|
||||
description='Custom metric',
|
||||
start_field='custom_start',
|
||||
end_field='custom_end'))
|
||||
|
||||
analyzer = RequestTimeBreakdown(config=custom_config)
|
||||
|
||||
# Verify custom config is used
|
||||
self.assertIsNotNone(
|
||||
analyzer.config.get_metric_by_name('custom_metric'))
|
||||
|
||||
|
||||
class TestIntegration(unittest.TestCase):
|
||||
"""Integration tests for the full workflow."""
|
||||
|
||||
def test_full_workflow(self):
|
||||
"""Test the complete workflow from file to diagram."""
|
||||
# Create test data
|
||||
test_data = [{
|
||||
'request_id': i,
|
||||
'perf_metrics': {
|
||||
'timing_metrics': {
|
||||
'server_arrival_time': float(i),
|
||||
'arrival_time': float(i) + 0.1,
|
||||
'first_scheduled_time': float(i) + 0.2,
|
||||
'first_token_time': float(i) + 0.5,
|
||||
'server_first_token_time': float(i) + 0.6
|
||||
}
|
||||
}
|
||||
} for i in range(5)]
|
||||
|
||||
# Create temporary files and run the complete workflow
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=True) as json_f, \
|
||||
tempfile.NamedTemporaryFile(suffix='.html', delete=True) as html_f:
|
||||
# Write test data to JSON file
|
||||
json.dump(test_data, json_f)
|
||||
json_f.flush() # Ensure data is written before reading
|
||||
|
||||
analyzer = RequestTimeBreakdown()
|
||||
timing_data = analyzer.parse_json_file(json_f.name)
|
||||
|
||||
# Verify parsing
|
||||
self.assertEqual(len(timing_data), 5)
|
||||
|
||||
# Mock the plot function to avoid actual file operations
|
||||
with patch(
|
||||
'tensorrt_llm.serve.scripts.time_breakdown.time_breakdown.pyo.plot'
|
||||
):
|
||||
analyzer.create_timing_diagram(timing_data, html_f.name)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()