TensorRT-LLM/tests/unittest/others/test_time_breakdown.py

#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""
Unit tests for time_breakdown module
Run tests with:
python -m pytest tests/unittest/others/test_time_breakdown.py -v
or
python -m unittest tests.unittest.others.test_time_breakdown
"""
import json
import math
import os
import tempfile
import unittest
from unittest.mock import patch

from tensorrt_llm.serve.scripts.time_breakdown import (RequestDataParser,
RequestTimeBreakdown,
TimingMetric,
TimingMetricsConfig)


class TestTimingMetric(unittest.TestCase):
"""Test TimingMetric class."""
def test_timing_metric_creation(self):
"""Test basic TimingMetric creation."""
metric = TimingMetric(name='test_metric',
display_name='Test Metric',
color='blue',
description='Test description',
start_field='start_time',
end_field='end_time',
server_type='ctx')
self.assertEqual(metric.name, 'test_metric')
self.assertEqual(metric.display_name, 'Test Metric')
self.assertEqual(metric.color, 'blue')
self.assertEqual(metric.description, 'Test description')
self.assertEqual(metric.start_field, 'start_time')
self.assertEqual(metric.end_field, 'end_time')
self.assertEqual(metric.server_type, 'ctx')
def test_calculate_duration_valid(self):
"""Test duration calculation with valid timestamps."""
metric = TimingMetric(name='test',
display_name='Test',
color='blue',
description='Test',
start_field='start_time',
end_field='end_time')
timing_data = {'start_time': 1.0, 'end_time': 3.5}
duration = metric.calculate_duration(timing_data)
self.assertEqual(duration, 2.5)
def test_calculate_duration_missing_start(self):
"""Test duration calculation with missing start time."""
metric = TimingMetric(name='test',
display_name='Test',
color='blue',
description='Test',
start_field='start_time',
end_field='end_time')
timing_data = {'end_time': 3.5}
duration = metric.calculate_duration(timing_data)
self.assertEqual(duration, 0.0)
def test_calculate_duration_missing_end(self):
"""Test duration calculation with missing end time."""
metric = TimingMetric(name='test',
display_name='Test',
color='blue',
description='Test',
start_field='start_time',
end_field='end_time')
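        # end_time=0 counts as missing, so the duration falls back to 0.0.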
timing_data = {'start_time': 1.0, 'end_time': 0}
duration = metric.calculate_duration(timing_data)
self.assertEqual(duration, 0.0)
def test_calculate_duration_negative(self):
"""Test duration calculation doesn't produce negative values."""
metric = TimingMetric(name='test',
display_name='Test',
color='blue',
description='Test',
start_field='start_time',
end_field='end_time')
timing_data = {'start_time': 5.0, 'end_time': 3.5}
duration = metric.calculate_duration(timing_data)
self.assertEqual(duration, 0.0)


class TestTimingMetricsConfig(unittest.TestCase):
"""Test TimingMetricsConfig class."""
def test_default_metrics_loaded(self):
"""Test that default metrics are loaded."""
config = TimingMetricsConfig()
# Should have multiple default metrics
self.assertGreater(len(config.metrics), 0)
# Check for expected metric names
metric_names = [m.name for m in config.metrics]
self.assertIn('ctx_preprocessing', metric_names)
self.assertIn('ctx_processing', metric_names)
def test_get_metric_by_name(self):
"""Test retrieving a metric by name."""
config = TimingMetricsConfig()
metric = config.get_metric_by_name('ctx_preprocessing')
self.assertIsNotNone(metric)
self.assertEqual(metric.name, 'ctx_preprocessing')
# Test non-existent metric
metric = config.get_metric_by_name('non_existent')
self.assertIsNone(metric)
def test_get_metrics_by_server(self):
"""Test retrieving metrics by server type."""
config = TimingMetricsConfig()
ctx_metrics = config.get_metrics_by_server('ctx')
self.assertGreater(len(ctx_metrics), 0)
# All returned metrics should be for 'ctx' server
for metric in ctx_metrics:
self.assertEqual(metric.server_type, 'ctx')
def test_disagg_relay_metric_exists(self):
"""Test that disagg_relay metric exists in default configuration."""
config = TimingMetricsConfig()
metric = config.get_metric_by_name('disagg_relay')
self.assertIsNotNone(metric)
self.assertEqual(metric.name, 'disagg_relay')
self.assertEqual(metric.display_name, 'Disagg Relay')
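        # The relay interval runs from the context server emitting its first
        # token until the generation server receives the relayed request.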
self.assertEqual(metric.start_field, 'ctx_server_first_token_time')
self.assertEqual(metric.end_field, 'gen_server_arrival_time')
self.assertEqual(metric.server_type, 'disagg')
def test_disagg_relay_calculation(self):
"""Test disagg_relay duration calculation."""
config = TimingMetricsConfig()
metric = config.get_metric_by_name('disagg_relay')
# Test with valid timing data
timing_data = {
'ctx_server_first_token_time': 1.5,
'gen_server_arrival_time': 2.0
}
duration = metric.calculate_duration(timing_data)
self.assertAlmostEqual(duration, 0.5, places=5)
# Test with missing fields
timing_data_missing = {'ctx_server_first_token_time': 1.5}
duration = metric.calculate_duration(timing_data_missing)
self.assertEqual(duration, 0.0)
def test_add_metric(self):
"""Test adding a new metric."""
config = TimingMetricsConfig()
initial_count = len(config.metrics)
new_metric = TimingMetric(name='custom_metric',
display_name='Custom Metric',
color='red',
description='Custom test metric',
start_field='start',
end_field='end')
config.add_metric(new_metric)
self.assertEqual(len(config.metrics), initial_count + 1)
self.assertIsNotNone(config.get_metric_by_name('custom_metric'))
def test_remove_metric(self):
"""Test removing a metric."""
config = TimingMetricsConfig()
initial_count = len(config.metrics)
# Add a test metric first
test_metric = TimingMetric(name='test_to_remove',
display_name='Test',
color='blue',
description='Test',
start_field='start',
end_field='end')
config.add_metric(test_metric)
# Remove it
config.remove_metric('test_to_remove')
self.assertEqual(len(config.metrics), initial_count)
self.assertIsNone(config.get_metric_by_name('test_to_remove'))


class TestRequestDataParser(unittest.TestCase):
"""Test RequestDataParser class."""
def test_parse_aggregated_format(self):
"""Test parsing aggregated (non-disaggregated) format."""
parser = RequestDataParser()
request_data = {
'request_id': 'req123',
'perf_metrics': {
'timing_metrics': {
'server_arrival_time': 1.0,
'arrival_time': 1.1,
'first_scheduled_time': 1.2,
'first_token_time': 1.5,
'server_first_token_time': 1.6
}
}
}
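        # In aggregated mode the parser maps plain timing fields onto ctx_*
        # keys and leaves the gen_* and disagg_* fields as NaN.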
parsed = parser.parse_request(request_data, 0)
self.assertEqual(parsed['request_index'], 'req123')
self.assertEqual(parsed['ctx_server_arrival_time'], 1.0)
self.assertEqual(parsed['ctx_arrival_time'], 1.1)
self.assertEqual(parsed['ctx_first_scheduled_time'], 1.2)
self.assertEqual(parsed['ctx_first_token_time'], 1.5)
self.assertEqual(parsed['ctx_server_first_token_time'], 1.6)
# Gen metrics should be NaN in aggregated format
self.assertTrue(math.isnan(parsed['gen_server_arrival_time']))
self.assertTrue(math.isnan(parsed['disagg_server_arrival_time']))
def test_parse_disaggregated_format(self):
"""Test parsing disaggregated format."""
parser = RequestDataParser()
request_data = {
'ctx_perf_metrics': {
'request_id': 'req456',
'perf_metrics': {
'timing_metrics': {
'server_arrival_time': 1.0,
'arrival_time': 1.1,
'first_scheduled_time': 1.2,
'first_token_time': 1.5,
'server_first_token_time': 1.6
}
}
},
'gen_perf_metrics': {
'perf_metrics': {
'timing_metrics': {
'server_arrival_time': 2.0,
'arrival_time': 2.1,
'first_scheduled_time': 2.2,
'first_token_time': 2.5,
'server_first_token_time': 2.6
}
}
},
'disagg_server_arrival_time': 0.5,
'disagg_server_first_token_time': 3.0
}
parsed = parser.parse_request(request_data, 0)
self.assertEqual(parsed['request_index'], 'req456')
# Context metrics
self.assertEqual(parsed['ctx_server_arrival_time'], 1.0)
self.assertEqual(parsed['ctx_arrival_time'], 1.1)
# Generation metrics
self.assertEqual(parsed['gen_server_arrival_time'], 2.0)
self.assertEqual(parsed['gen_arrival_time'], 2.1)
# Disaggregation metrics
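        # (top-level timestamps observed at the disaggregated endpoint itself)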
self.assertEqual(parsed['disagg_server_arrival_time'], 0.5)
self.assertEqual(parsed['disagg_server_first_token_time'], 3.0)
def test_parse_missing_fields(self):
"""Test parsing with missing fields (should default to 0)."""
parser = RequestDataParser()
request_data = {
'request_id': 'req789',
'perf_metrics': {
'timing_metrics': {}
}
}
parsed = parser.parse_request(request_data, 0)
# All timing fields should default to NaN
self.assertTrue(math.isnan(parsed['ctx_server_first_token_time']))
self.assertTrue(math.isnan(parsed['ctx_arrival_time']))
self.assertTrue(math.isnan(parsed['gen_server_arrival_time']))
def test_parse_uses_index_as_fallback(self):
"""Test that index is used when request_id is missing."""
parser = RequestDataParser()
request_data = {'perf_metrics': {'timing_metrics': {}}}
parsed = parser.parse_request(request_data, 42)
self.assertEqual(parsed['request_index'], 42)
def test_parse_disagg_relay_fields(self):
"""Test parsing disaggregated format includes disagg_relay timing fields."""
parser = RequestDataParser()
request_data = {
'ctx_perf_metrics': {
'request_id': 'req_relay',
'perf_metrics': {
'timing_metrics': {
'server_arrival_time': 1.0,
'arrival_time': 1.1,
'first_scheduled_time': 1.2,
'first_token_time': 1.5,
                        'server_first_token_time': 1.6  # Start of disagg_relay
}
}
},
'gen_perf_metrics': {
'perf_metrics': {
'timing_metrics': {
                        'server_arrival_time': 2.0,  # End of disagg_relay
'arrival_time': 2.1,
'first_scheduled_time': 2.2,
'first_token_time': 2.5,
'server_first_token_time': 2.6
}
}
},
'disagg_server_arrival_time': 0.5,
'disagg_server_first_token_time': 3.0
}
parsed = parser.parse_request(request_data, 0)
# Verify that both fields required for disagg_relay are present
self.assertEqual(parsed['ctx_server_first_token_time'], 1.6)
self.assertEqual(parsed['gen_server_arrival_time'], 2.0)
# The relay time should be: gen_server_arrival_time - ctx_server_first_token_time
# = 2.0 - 1.6 = 0.4 seconds
config = TimingMetricsConfig()
disagg_relay_metric = config.get_metric_by_name('disagg_relay')
relay_duration = disagg_relay_metric.calculate_duration(parsed)
self.assertAlmostEqual(relay_duration, 0.4, places=5)


class TestRequestTimeBreakdown(unittest.TestCase):
"""Test RequestTimeBreakdown class."""
def setUp(self):
"""Set up test fixtures."""
self.analyzer = RequestTimeBreakdown()
# Create a temporary JSON file for testing
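        # (aggregated format: plain perf_metrics with no ctx_/gen_ split)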
self.test_data = [{
'request_id': 0,
'perf_metrics': {
'timing_metrics': {
'server_arrival_time': 1.0,
'arrival_time': 1.1,
'first_scheduled_time': 1.2,
'first_token_time': 1.5,
'server_first_token_time': 1.6
}
}
}, {
'request_id': 1,
'perf_metrics': {
'timing_metrics': {
'server_arrival_time': 2.0,
'arrival_time': 2.1,
'first_scheduled_time': 2.3,
'first_token_time': 2.7,
'server_first_token_time': 2.8
}
}
}]
def test_parse_json_file(self):
"""Test parsing a JSON file."""
# Create a temporary file
with tempfile.NamedTemporaryFile(mode='w', suffix='.json',
delete=False) as f:
json.dump(self.test_data, f)
temp_file = f.name
try:
timing_data = self.analyzer.parse_json_file(temp_file)
self.assertEqual(len(timing_data), 2)
# Check first request
self.assertEqual(timing_data[0]['request_index'], 0)
self.assertEqual(timing_data[0]['ctx_server_arrival_time'], 1.0)
# Check that durations were calculated
self.assertIn('ctx_preprocessing_time', timing_data[0])
self.assertIn('ctx_queue_time', timing_data[0])
# Verify a specific duration calculation
# ctx_preprocessing = ctx_arrival_time - ctx_server_arrival_time
expected_preprocessing = 1.1 - 1.0
self.assertAlmostEqual(timing_data[0]['ctx_preprocessing_time'],
expected_preprocessing,
places=5)
finally:
os.unlink(temp_file)
def test_parse_json_file_not_found(self):
"""Test parsing a non-existent file."""
with self.assertRaises(SystemExit):
self.analyzer.parse_json_file('non_existent_file.json')
def test_parse_json_file_invalid_json(self):
"""Test parsing an invalid JSON file."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json',
delete=False) as f:
f.write("{ invalid json")
temp_file = f.name
try:
with self.assertRaises(SystemExit):
self.analyzer.parse_json_file(temp_file)
finally:
os.unlink(temp_file)
def test_create_timing_diagram(self):
"""Test creating a timing diagram."""
# Create sample timing data
timing_data = [{
'request_index': 0,
'ctx_preprocessing_time': 0.1,
'ctx_queue_time': 0.2,
'ctx_processing_time': 0.3,
'ctx_postprocessing_time': 0.05,
'gen_preprocessing_time': 0,
'gen_queue_time': 0,
'gen_postprocessing_time': 0,
'disagg_preprocessing_time': 0,
'disagg_postprocessing_time': 0,
}, {
'request_index': 1,
'ctx_preprocessing_time': 0.15,
'ctx_queue_time': 0.25,
'ctx_processing_time': 0.35,
'ctx_postprocessing_time': 0.06,
'gen_preprocessing_time': 0,
'gen_queue_time': 0,
'gen_postprocessing_time': 0,
'disagg_preprocessing_time': 0,
'disagg_postprocessing_time': 0,
}]
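        # Keys mirror the *_time durations produced by parse_json_file.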
with tempfile.NamedTemporaryFile(suffix='.html', delete=False) as f:
temp_file = f.name
try:
# Mock plotly to avoid actual file creation
with patch(
'tensorrt_llm.serve.scripts.time_breakdown.time_breakdown.pyo.plot'
) as mock_plot:
self.analyzer.create_timing_diagram(timing_data, temp_file)
# Verify that plot was called
mock_plot.assert_called_once()
finally:
if os.path.exists(temp_file):
os.unlink(temp_file)
def test_create_timing_diagram_empty_data(self):
"""Test creating a diagram with empty data."""
# Should handle gracefully without creating a file
with patch('builtins.print') as mock_print:
self.analyzer.create_timing_diagram([])
mock_print.assert_called_with("No timing data to visualize.")
def test_show_statistics(self):
"""Test showing statistics."""
timing_data = [{
'ctx_preprocessing_time': 0.1,
'ctx_queue_time': 0.2,
'ctx_processing_time': 0.3,
}, {
'ctx_preprocessing_time': 0.15,
'ctx_queue_time': 0.25,
'ctx_processing_time': 0.35,
}]
# Capture printed output
with patch('builtins.print') as mock_print:
self.analyzer.show_statistics(timing_data)
# Should have printed something
self.assertTrue(mock_print.called)
# Check for expected content in printed output
printed_output = ' '.join(
[str(call[0][0]) for call in mock_print.call_args_list])
self.assertIn('Total requests', printed_output)
def test_show_statistics_empty_data(self):
"""Test showing statistics with empty data."""
with patch('builtins.print') as mock_print:
self.analyzer.show_statistics([])
mock_print.assert_called_with("No timing data to analyze.")
def test_custom_config(self):
"""Test using a custom configuration."""
custom_config = TimingMetricsConfig()
custom_config.add_metric(
TimingMetric(name='custom_metric',
display_name='Custom',
color='red',
description='Custom metric',
start_field='custom_start',
end_field='custom_end'))
analyzer = RequestTimeBreakdown(config=custom_config)
# Verify custom config is used
self.assertIsNotNone(
analyzer.config.get_metric_by_name('custom_metric'))


class TestIntegration(unittest.TestCase):
"""Integration tests for the full workflow."""
def test_full_workflow(self):
"""Test the complete workflow from file to diagram."""
# Create test data
test_data = [{
'request_id': i,
'perf_metrics': {
'timing_metrics': {
'server_arrival_time': float(i),
'arrival_time': float(i) + 0.1,
'first_scheduled_time': float(i) + 0.2,
'first_token_time': float(i) + 0.5,
'server_first_token_time': float(i) + 0.6
}
}
} for i in range(5)]
# Create temporary files and run the complete workflow
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=True) as json_f, \
tempfile.NamedTemporaryFile(suffix='.html', delete=True) as html_f:
# Write test data to JSON file
json.dump(test_data, json_f)
json_f.flush() # Ensure data is written before reading
analyzer = RequestTimeBreakdown()
timing_data = analyzer.parse_json_file(json_f.name)
# Verify parsing
self.assertEqual(len(timing_data), 5)
# Mock the plot function to avoid actual file operations
with patch(
'tensorrt_llm.serve.scripts.time_breakdown.time_breakdown.pyo.plot'
):
analyzer.create_timing_diagram(timing_data, html_f.name)
def test_full_workflow_with_disagg_relay(self):
"""Test the complete workflow with disaggregated data including disagg_relay."""
# Create disaggregated test data
test_data = [{
'ctx_perf_metrics': {
'request_id': i,
'perf_metrics': {
'timing_metrics': {
'server_arrival_time': float(i),
'arrival_time': float(i) + 0.1,
'first_scheduled_time': float(i) + 0.2,
'first_token_time': float(i) + 0.5,
'server_first_token_time': float(i) + 0.6
}
}
},
'gen_perf_metrics': {
'perf_metrics': {
'timing_metrics': {
'server_arrival_time': float(i) + 1.0,
'arrival_time': float(i) + 1.1,
'first_scheduled_time': float(i) + 1.2,
'first_token_time': float(i) + 1.5,
'server_first_token_time': float(i) + 1.6
}
}
},
'disagg_server_arrival_time': float(i) - 0.5,
'disagg_server_first_token_time': float(i) + 2.0
} for i in range(3)]
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=True) as json_f, \
tempfile.NamedTemporaryFile(suffix='.html', delete=True) as html_f:
# Write test data to JSON file
json.dump(test_data, json_f)
json_f.flush()
analyzer = RequestTimeBreakdown()
timing_data = analyzer.parse_json_file(json_f.name)
# Verify parsing
self.assertEqual(len(timing_data), 3)
# Verify disagg_relay_time is calculated
for data in timing_data:
self.assertIn('disagg_relay_time', data)
# Expected relay time: gen_server_arrival_time - ctx_server_first_token_time
# = (i + 1.0) - (i + 0.6) = 0.4
self.assertAlmostEqual(data['disagg_relay_time'], 0.4, places=5)
# Mock the plot function to avoid actual file operations
with patch(
'tensorrt_llm.serve.scripts.time_breakdown.time_breakdown.pyo.plot'
):
analyzer.create_timing_diagram(timing_data, html_f.name)
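

# A minimal usage sketch of the tool under test, assembled only from the APIs
# exercised in the tests above (the file names are placeholders):
#
#   analyzer = RequestTimeBreakdown()
#   timing_data = analyzer.parse_json_file('perf_metrics.json')
#   analyzer.show_statistics(timing_data)
#   analyzer.create_timing_diagram(timing_data, 'breakdown.html')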


if __name__ == '__main__':
unittest.main()