[TRTLLM-10851][feat] Add line_profiler tool for host overhead analysis. (#11232)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
This commit is contained in:
Yukun He 2026-02-15 16:18:10 +08:00 committed by GitHub
parent b003355050
commit ed404f9298
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 1080 additions and 2 deletions

View File

@ -40,3 +40,4 @@ aiperf==0.3.0
nanobind>=2.9.0
nixl==0.8.0
hf-transfer==0.1.9
line_profiler

View File

@ -36,6 +36,8 @@ from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import CpType
from tensorrt_llm.runtime.generation import CUASSERT
from tensorrt_llm.tools.layer_wise_benchmarks import get_calibrator
from tensorrt_llm.tools.profiler.host_profile_tools.host_profiler import \
host_profiler_context
from ..distributed import Distributed
from ..expert_statistic import ExpertStatistic
@ -560,8 +562,12 @@ class PyExecutor:
def _event_loop_wrapper(self):
    """Run the executor event loop with GC tuning and optional line profiling.

    The line profiler is only enabled when TLLM_LINE_PROFILER_PATH is set
    AND this is not the warmup/memory-estimation pass, so that warmup does
    not save incomplete results that the real run would overwrite anyway.
    """
    try:
        # Skip line profiler during warmup/memory estimation phase to avoid
        # saving incomplete results that would be overwritten anyway.
        enable_profiler = bool(os.environ.get(
            "TLLM_LINE_PROFILER_PATH")) and not self.is_warmup
        # BUG FIX: the rendered diff had customized_gc_thresholds() entered
        # twice (stale pre-diff nesting); enter each context exactly once.
        with host_profiler_context(enable=enable_profiler), \
                customized_gc_thresholds(self.garbage_collection_gen0_threshold):
            self.event_loop()
    except Exception as e:
        logger.error(f"Error in event loop: {e}")
        # NOTE(review): the visible chunk ends here — confirm the original
        # re-raises after logging so executor failures are not swallowed.

View File

@ -0,0 +1,15 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Profiler tools for TensorRT-LLM."""

View File

@ -0,0 +1,15 @@
"""Host-side profiling tools for CPU overhead analysis."""
from .host_profiler import (
HostProfiler,
get_global_profiler,
host_profiler_context,
set_global_profiler,
)
__all__ = [
"HostProfiler",
"get_global_profiler",
"host_profiler_context",
"set_global_profiler",
]

View File

@ -0,0 +1,763 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Host-side profiler for analyzing CPU overhead in the PyExecutor.
This module provides a HostProfiler class that wraps line_profiler to measure
line-by-line execution time of critical functions in the executor worker thread.
Usage:
Set environment variable TLLM_LINE_PROFILER_PATH to enable:
TLLM_LINE_PROFILER_PATH=./lp_results.txt pytest ...
Or use programmatically:
profiler = HostProfiler(output_path="./results.txt")
profiler.start()
# ... run code ...
profiler.stop()
"""
import importlib
import os
import threading
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional
from tensorrt_llm.logger import logger
# Env var naming the output file for line_profiler results.
# Profiling is enabled iff this variable is set (see HostProfiler.should_profile).
LINE_PROFILER_PATH_ENV_VAR = "TLLM_LINE_PROFILER_PATH"
# Env var listing additional functions to profile (comma-separated).
# Entry format: "module.Class.method" (class method) or "module.path::function"
# (standalone function) — parsed by HostProfiler._parse_env_targets.
LINE_PROFILER_FUNCTIONS_ENV_VAR = "TLLM_LINE_PROFILER_FUNCTIONS"
@dataclass
class ProfileTarget:
    """A single function or method selected for line profiling.

    Handles both class methods and standalone module-level functions.

    Examples:
        Class method:
            ProfileTarget("module.path", "ClassName", "method_name")
            -> resolves to module.path.ClassName.method_name
        Standalone function:
            ProfileTarget("module.path", None, "function_name")
            -> resolves to module.path.function_name
    """

    module_path: str
    class_name: Optional[str]  # None for standalone functions
    method_name: str

    @property
    def full_path(self) -> str:
        """Dotted path of the target, omitting the class part when absent."""
        pieces = [self.module_path]
        if self.class_name is not None:
            pieces.append(self.class_name)
        pieces.append(self.method_name)
        return ".".join(pieces)

    @property
    def is_standalone(self) -> bool:
        """True when the target is a module-level function, not a class method."""
        return self.class_name is None

    def resolve(self) -> Optional[Callable]:
        """Resolve the target into the actual (unwrapped) function object.

        Returns:
            The innermost function, or None when the module/attribute lookup
            fails (a warning is logged in that case).

        Note:
            Decorated functions (e.g. @torch.inference_mode, @nvtx_range)
            MUST be unwrapped because line_profiler keys its tracing on
            __code__ identity: a wrapper's __code__ differs from the wrapped
            function's, so profiling the wrapper would only ever show the
            wrapper's few lines. Following the __wrapped__ chain yields the
            real function body whose lines we want traced.
        """
        try:
            module = importlib.import_module(self.module_path)
            owner = module if self.is_standalone else getattr(module, self.class_name)
            func = getattr(owner, self.method_name)
            # Walk the decorator chain down to the innermost function so
            # line_profiler traces the real body's __code__.
            while hasattr(func, "__wrapped__"):
                func = func.__wrapped__
            return func
        except (ImportError, AttributeError) as e:
            logger.warning(f"Failed to resolve profile target {self.full_path}: {e}")
            return None
# Default functions to profile for host overhead analysis.
# Hierarchical config: {module_path: {class_name: [method_names]}}
# Use None as class_name for standalone module-level functions.
#
# Wildcard support (expanded by _expand_profile_config):
#   - ["*"] as method list: profile all methods of the class
#   - {None: ["*"]}: profile all standalone functions in the module
#   - {"*": ["*"]}: profile all classes and all their methods in the module
_PYEXEC = "tensorrt_llm._torch.pyexecutor"
_DEFAULT_PROFILE_CONFIG: Dict[str, Dict[Optional[str], List[str]]] = {
    f"{_PYEXEC}.py_executor": {
        "PyExecutor": [
            "_prepare_and_schedule_batch",
            "_schedule",
            "_forward_step",
            "_sample_async",
            "_update_requests",
            "_update_request_states",
            "_fetch_and_activate_new_requests",
            "_handle_responses",
            "_handle_canceled_requests",
            "_enqueue_responses",
        ],
    },
    f"{_PYEXEC}.sampler": {
        "TorchSampler": [
            "sample_async",
            "update_requests",
            "_process_requests",
            "_write_finish_reasons",
            "_prepare_beam_search",
            "_select_generated_logits",
            "_sample_batched_by_strategy",
        ],
        # Standalone module-level functions (use None as class_name)
        None: [
            "_group_requests_by_strategy_key",
        ],
    },
    f"{_PYEXEC}.resource_manager": {
        "ResourceManager": ["prepare_resources", "update_resources", "free_resources"],
        "KVCacheManager": ["prepare_resources", "update_resources"],
    },
    f"{_PYEXEC}.scheduler": {
        "RequestScheduler": ["schedule_request"],
    },
    f"{_PYEXEC}.executor_request_queue": {
        "ExecutorRequestQueue": [
            "_fetch_new_requests_attention_tp",
            "_fetch_new_requests_attention_dp",
            "_fetch_and_process_requests",
            "_merge_requests",
            "fetch_new_requests",
        ],
    },
}
def _get_all_methods_from_class(cls: type) -> List[str]:
"""Get all user-defined methods from a class (excluding inherited, dunder, nested classes, and properties).
Only includes items that have actual executable code (__code__ attribute):
- Regular instance methods (def foo(self): ...)
- Static methods (@staticmethod)
- Class methods (@classmethod)
Excludes:
- Nested classes (e.g., dataclasses like Args, Store)
- Properties (usually trivial getters, complex to profile)
- Constants and type aliases
- Dunder methods (__init__, __repr__, etc.)
Args:
cls: The class to introspect.
Returns:
List of method names defined directly on the class.
"""
import inspect
methods = []
for name, member in cls.__dict__.items():
# Skip dunder methods
if name.startswith("__") and name.endswith("__"):
continue
# Skip nested classes
if inspect.isclass(member):
continue
# Skip properties (usually trivial, complex to handle with line_profiler)
if isinstance(member, property):
continue
# Extract the underlying function from descriptors
func = None
if isinstance(member, staticmethod):
func = member.__func__
elif isinstance(member, classmethod):
func = member.__func__
elif inspect.isfunction(member):
func = member
# Only include if we have a valid function with executable code
if func is not None and hasattr(func, "__code__"):
methods.append(name)
return methods
def _get_all_functions_from_module(module_path: str) -> List[str]:
"""Get all user-defined functions from a module (excluding imported ones).
Args:
module_path: The module path to introspect.
Returns:
List of function names defined directly in the module.
"""
import inspect
try:
module = importlib.import_module(module_path)
except ImportError as e:
logger.warning(f"Failed to import module {module_path} for introspection: {e}")
return []
functions = []
for name, member in inspect.getmembers(module, inspect.isfunction):
# Only include functions defined in this module (not imported)
if member.__module__ == module_path:
# Skip private functions starting with underscore if desired
# For profiling, we likely want to include them
functions.append(name)
return functions
def _get_all_classes_from_module(module_path: str) -> List[str]:
"""Get all user-defined classes from a module (excluding imported ones).
Args:
module_path: The module path to introspect.
Returns:
List of class names defined directly in the module.
"""
import inspect
try:
module = importlib.import_module(module_path)
except ImportError as e:
logger.warning(f"Failed to import module {module_path} for introspection: {e}")
return []
classes = []
for name, member in inspect.getmembers(module, inspect.isclass):
# Only include classes defined in this module (not imported)
if member.__module__ == module_path:
classes.append(name)
return classes
def _expand_profile_config(
    config: Dict[str, Dict[Optional[str], List[str]]],
) -> List[ProfileTarget]:
    """Expand a hierarchical profile config into a flat list of ProfileTargets.

    Args:
        config: Mapping of module_path -> class_name -> method names.
            Use None as class_name for standalone module-level functions.
            Special markers:
              - "*" in the method list: profile all methods of the class
              - "*" as class_name key with ["*"]: profile all classes and
                their methods
              - None as class_name key with ["*"]: profile all standalone
                functions

    Returns:
        List of ProfileTarget objects.

    Examples:
        # Profile all methods of TorchSampler class:
        {"module.sampler": {"TorchSampler": ["*"]}}
        # Profile all standalone functions in a module:
        {"module.sampler": {None: ["*"]}}
        # Profile all classes and all their methods in a module:
        {"module.sampler": {"*": ["*"]}}
        # Mix explicit and wildcard:
        {"module.sampler": {"TorchSampler": ["*"], "OtherClass": ["specific_method"]}}
    """
    targets: List[ProfileTarget] = []
    for module_path, classes in config.items():
        for class_name, methods in classes.items():
            # Wildcard class key: "*" means every class in the module.
            if class_name == "*":
                if methods == ["*"]:
                    # Profile all methods of all classes in the module;
                    # per-class failures are logged and skipped.
                    for cls_name in _get_all_classes_from_module(module_path):
                        try:
                            module = importlib.import_module(module_path)
                            cls = getattr(module, cls_name)
                            for method_name in _get_all_methods_from_class(cls):
                                targets.append(
                                    ProfileTarget(module_path, cls_name, method_name))
                        except (ImportError, AttributeError) as e:
                            logger.warning(
                                f"Failed to introspect {module_path}.{cls_name}: {e}")
                else:
                    # BUG FIX: an explicit method list under the "*" class key
                    # was previously dropped without any diagnostic; surface
                    # the misconfiguration instead of silently ignoring it.
                    logger.warning(
                        f"Ignoring explicit methods {methods} under wildcard "
                        f"class '*' for module {module_path}; "
                        "use {'*': ['*']} or name the class explicitly.")
                continue
            # Wildcard method list: ["*"] means all methods of the class/module.
            if methods == ["*"]:
                if class_name is None:
                    # All standalone functions in the module.
                    for func_name in _get_all_functions_from_module(module_path):
                        targets.append(ProfileTarget(module_path, None, func_name))
                else:
                    # All methods of one specific class.
                    try:
                        module = importlib.import_module(module_path)
                        cls = getattr(module, class_name)
                        for method_name in _get_all_methods_from_class(cls):
                            targets.append(
                                ProfileTarget(module_path, class_name, method_name))
                    except (ImportError, AttributeError) as e:
                        logger.warning(
                            f"Failed to introspect {module_path}.{class_name}: {e}")
                continue
            # Normal case: explicit method list.
            for method_name in methods:
                targets.append(ProfileTarget(module_path, class_name, method_name))
    return targets
DEFAULT_PROFILE_TARGETS: List[ProfileTarget] = _expand_profile_config(_DEFAULT_PROFILE_CONFIG)
class HostProfiler:
    """Host-side profiler for measuring CPU overhead in the executor.

    This class wraps line_profiler to provide line-by-line timing analysis
    of critical functions in the PyExecutor worker thread.

    Attributes:
        output_path: Path to save profiling results.
        targets: List of ProfileTarget objects specifying functions to profile.
        enabled: Whether profiling is currently active.

    Example:
        >>> profiler = HostProfiler(output_path="./results.txt")
        >>> profiler.add_target(
        ...     ProfileTarget(
        ...         module_path="my_module",
        ...         class_name="MyClass",
        ...         method_name="my_method",
        ...     )
        ... )
        >>> profiler.start()
        >>> # ... run code ...
        >>> profiler.stop()
    """

    def __init__(
        self,
        output_path: Optional[str] = None,
        targets: Optional[List[ProfileTarget]] = None,
        use_defaults: bool = True,
    ):
        """Initialize the host profiler.

        Args:
            output_path: Path to save results. If None, uses env var
                TLLM_LINE_PROFILER_PATH.
            targets: List of ProfileTarget objects. If None and
                use_defaults=True, uses DEFAULT_PROFILE_TARGETS.
            use_defaults: Whether to include default profile targets.
        """
        self.output_path = output_path or os.environ.get(LINE_PROFILER_PATH_ENV_VAR)
        self.targets: List[ProfileTarget] = []
        self._line_profiler = None
        self._enabled = False
        # Add default targets if requested
        if use_defaults:
            self.targets.extend(DEFAULT_PROFILE_TARGETS)
        # Add custom targets
        if targets:
            self.targets.extend(targets)
        # Parse additional targets from environment variable
        self._parse_env_targets()

    def _parse_env_targets(self) -> None:
        """Parse additional profile targets from environment variable.

        Supported formats:
            - module.Class.method -> class method
            - module::function -> standalone function (uses :: as delimiter)

        Malformed entries are skipped with a warning rather than raising.
        """
        env_funcs = os.environ.get(LINE_PROFILER_FUNCTIONS_ENV_VAR, "")
        if not env_funcs:
            return
        for func_path in env_funcs.split(","):
            func_path = func_path.strip()
            if not func_path:
                continue
            # Check for standalone function format: "module::function"
            if "::" in func_path:
                parts = func_path.rsplit("::", 1)
                if len(parts) != 2:
                    logger.warning(
                        f"Invalid standalone function path '{func_path}'. "
                        "Expected format: module.path::function_name"
                    )
                    continue
                module_path, method_name = parts
                self.targets.append(
                    ProfileTarget(
                        module_path=module_path,
                        class_name=None,  # Standalone function
                        method_name=method_name,
                    )
                )
            else:
                # Class method format: "module.Class.method"
                parts = func_path.rsplit(".", 2)
                if len(parts) < 3:
                    logger.warning(
                        f"Invalid function path '{func_path}'. Expected format: "
                        "module.Class.method (class method) or module::function (standalone)"
                    )
                    continue
                # rsplit keeps nested module paths intact in parts[:-2].
                method_name = parts[-1]
                class_name = parts[-2]
                module_path = ".".join(parts[:-2])
                self.targets.append(
                    ProfileTarget(
                        module_path=module_path,
                        class_name=class_name,
                        method_name=method_name,
                    )
                )

    def add_target(self, target: ProfileTarget) -> "HostProfiler":
        """Add a profile target.

        Args:
            target: The ProfileTarget to add.

        Returns:
            Self for chaining.
        """
        self.targets.append(target)
        return self

    def add_function(
        self,
        module_path: str,
        class_name: Optional[str],
        method_name: str,
    ) -> "HostProfiler":
        """Add a function to profile by specifying its path.

        Args:
            module_path: The module path (e.g., "tensorrt_llm._torch.pyexecutor.sampler")
            class_name: The class name (e.g., "TorchSampler"), or None for
                standalone functions
            method_name: The method/function name (e.g., "_process_requests")

        Returns:
            Self for chaining.

        Examples:
            # Add a class method
            profiler.add_function("my.module", "MyClass", "my_method")
            # Add a standalone function
            profiler.add_function("my.module", None, "my_function")
        """
        return self.add_target(
            ProfileTarget(
                module_path=module_path,
                class_name=class_name,
                method_name=method_name,
            )
        )

    def add_standalone_function(
        self,
        module_path: str,
        function_name: str,
    ) -> "HostProfiler":
        """Add a standalone module-level function to profile.

        Convenience wrapper for adding functions that are not class methods.

        Args:
            module_path: The module path (e.g., "tensorrt_llm._torch.pyexecutor.sampler")
            function_name: The function name (e.g., "_group_requests_by_strategy_key")

        Returns:
            Self for chaining.
        """
        return self.add_target(
            ProfileTarget(
                module_path=module_path,
                class_name=None,
                method_name=function_name,
            )
        )

    def clear_targets(self) -> "HostProfiler":
        """Clear all profile targets (including defaults).

        Useful for starting with a clean slate, or replacing the default
        targets with a custom set.

        Returns:
            Self for chaining.

        Example:
            # Clear defaults and add only specific targets
            profiler = HostProfiler(use_defaults=True)
            profiler.clear_targets().add_function("my.module", "MyClass", "method")
            # Or start fresh without defaults
            profiler = HostProfiler(use_defaults=False)
        """
        self.targets.clear()
        return self

    @property
    def is_available(self) -> bool:
        """Check if line_profiler is available.

        BUG FIX: previously this always returned True, which made the
        "not installed" branch in start() unreachable and pushed a missing
        line_profiler into the generic error path instead of the intended
        install-hint warning. Now it actually probes for the package.
        """
        try:
            import line_profiler  # noqa: F401
        except ImportError:
            return False
        return True

    @property
    def should_profile(self) -> bool:
        """Check if profiling should be enabled (output path is set)."""
        return self.output_path is not None

    @property
    def enabled(self) -> bool:
        """Check if profiling is currently active."""
        return self._enabled

    def start(self) -> bool:
        """Start profiling.

        Resolves all targets, registers them with a fresh LineProfiler and
        enables it on the calling thread.

        Returns:
            True if profiling started successfully, False otherwise.
        """
        if not self.should_profile:
            logger.info("Line profiler not enabled (no output path specified)")
            return False
        if not self.is_available:
            logger.warning("line_profiler not installed. Install with: pip install line_profiler")
            return False
        if self._enabled:
            logger.warning("Line profiler already started")
            return True
        try:
            from line_profiler import LineProfiler

            self._line_profiler = LineProfiler()
            # Add all target functions; unresolvable ones are skipped
            # (ProfileTarget.resolve already logged a warning for them).
            resolved_count = 0
            for target in self.targets:
                func = target.resolve()
                if func is not None:
                    logger.info(
                        f"line profiler func code ID: {id(func.__code__)}, target: {target.full_path}"
                    )
                    self._line_profiler.add_function(func)
                    resolved_count += 1
            if resolved_count == 0:
                logger.warning("No profile targets could be resolved")
                self._line_profiler = None
                return False
            self._line_profiler.enable()
            self._enabled = True
            # Remember which thread enabled tracing (for diagnostics).
            self._profiler_thread_id = threading.current_thread().ident
            logger.info(
                f"Line profiler enabled with {resolved_count}/{len(self.targets)} targets. "
                f"Thread ID: {self._profiler_thread_id}, Thread name: {threading.current_thread().name}. "
                f"Results will be saved to: {self.output_path}"
            )
            dump_profiler_functions()
            return True
        except Exception as e:
            logger.error(f"Failed to start line profiler: {e}")
            self._line_profiler = None
            return False

    def stop(self) -> bool:
        """Stop profiling and save results to ``output_path``.

        Returns:
            True if results were saved successfully, False otherwise.
        """
        if not self._enabled or self._line_profiler is None:
            return False
        try:
            self._line_profiler.disable()
            self._enabled = False
            # Save results
            with open(self.output_path, "w") as f:
                self._line_profiler.print_stats(stream=f)
            logger.info(f"Line profiler results saved to: {self.output_path}")
            return True
        except Exception as e:
            logger.error(f"Failed to save line profiler results: {e}")
            return False
        finally:
            # Drop the profiler either way so a later start() begins fresh.
            self._line_profiler = None

    @contextmanager
    def profile(self):
        """Context manager for profiling.

        Usage:
            with profiler.profile():
                # ... code to profile ...
        """
        started = self.start()
        try:
            yield self
        finally:
            if started:
                self.stop()

    def get_stats_string(self) -> Optional[str]:
        """Get profiling stats as a string (without saving to file).

        Returns:
            Stats string if profiling is active, None otherwise.
        """
        if self._line_profiler is None:
            return None
        import io

        stream = io.StringIO()
        self._line_profiler.print_stats(stream=stream)
        return stream.getvalue()

    def list_targets(self) -> List[str]:
        """List all configured profile targets.

        Returns:
            List of dotted target paths.
        """
        return [t.full_path for t in self.targets]
# Process-wide profiler instance shared with the executor worker thread.
_global_profiler: Optional[HostProfiler] = None


def get_global_profiler() -> Optional[HostProfiler]:
    """Return the global profiler instance (None when no profiler is set)."""
    return _global_profiler


def dump_profiler_functions() -> None:
    """Log every function registered with the active line profiler.

    Debugging aid. Only rank 0 logs, to avoid interleaved output from
    multiple ranks; emits a single notice when no profiler is active.
    """
    # Imported lazily: avoids circular imports and tolerates the case where
    # MPI is not initialized (single-rank runs proceed unconditionally).
    try:
        from tensorrt_llm._utils import mpi_rank
        if mpi_rank() != 0:
            return
    except Exception:
        pass  # MPI unavailable -> assume single-rank and proceed.
    profiler = get_global_profiler()
    if profiler is None or profiler._line_profiler is None:
        logger.info("No line profiler active")
        return
    lp = profiler._line_profiler
    logger.info(f"=== Line Profiler State: {len(lp.functions)} functions registered ===")
    for fn in lp.functions:
        logger.info(f" {fn.__module__}.{fn.__qualname__}, code id: {id(fn.__code__)}")
    logger.info("=== End Line Profiler State ===")


def set_global_profiler(profiler: Optional[HostProfiler]) -> None:
    """Install (or clear, by passing None) the global profiler instance."""
    global _global_profiler
    _global_profiler = profiler
@contextmanager
def host_profiler_context(enable: bool = True, output_path: Optional[str] = None):
    """Context manager for host profiling in the worker thread.

    Main entry point for profiling in PyExecutor._event_loop_wrapper.

    Args:
        enable: When False, profiling is skipped entirely and None is yielded.
        output_path: Path to save results. If None, uses env var.

    Usage:
        with host_profiler_context():
            # ... event loop code ...
    """
    if enable:
        profiler = HostProfiler(output_path=output_path)
        set_global_profiler(profiler)
        active = profiler.start()
        try:
            yield profiler
        finally:
            # Only flush results if start() actually succeeded; always
            # clear the global handle so stale profilers cannot leak.
            if active:
                profiler.stop()
            set_global_profiler(None)
    else:
        yield None

View File

@ -0,0 +1,278 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the HostProfiler module.
Tests cover:
1. Core API functionality (add_*, clear_targets, chaining)
2. Line profiler integration with report validation
3. E2E test with actual model inference
"""
import os
import re
import sys
import tempfile
import pytest
# Add path for test utilities
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from utils.llm_data import llm_models_root
from tensorrt_llm.tools.profiler.host_profile_tools.host_profiler import HostProfiler, ProfileTarget
def _sample_function_to_profile(n: int) -> int:
"""A simple function with multiple lines to profile."""
total = 0
for i in range(n):
total += i
return total
def _another_function(x: int, y: int) -> int:
"""Another function for testing."""
result = x + y
result *= 2
return result
class _SampleClassToProfile:
"""A sample class with methods to profile."""
def instance_method(self, n: int) -> int:
"""Instance method to profile."""
result = 0
for i in range(n):
result += i * 2
return result
def _patch_mpi_pool_session_for_env(mocker, env_vars: dict):
    """Patch MpiPoolSession so spawned MPI workers inherit *env_vars*."""
    from mpi4py.futures import MPIPoolExecutor

    from tensorrt_llm.llmapi.mpi_session import MpiPoolSession

    def _start_pool_with_env(self):
        # Same guard as the real implementation, then inject the env dict.
        assert not self.mpi_pool, "MPI session already started"
        self.mpi_pool = MPIPoolExecutor(
            max_workers=self.n_workers, path=sys.path, env=env_vars)

    mocker.patch.object(MpiPoolSession, "_start_mpi_pool", _start_pool_with_env)
def test_add_and_clear_targets():
    """add_*, clear_targets() and fluent chaining behave as documented."""
    profiler = HostProfiler(use_defaults=False)

    # Every add_* helper returns self, so calls can be chained.
    chained = (
        profiler.add_function("m1", "C1", "f1")
        .add_standalone_function("m2", "f2")
        .add_target(ProfileTarget("m3", "C3", "f3"))
    )
    assert chained is profiler
    assert len(profiler.targets) == 3
    assert profiler.list_targets() == ["m1.C1.f1", "m2.f2", "m3.C3.f3"]

    # clear_targets() also chains and empties the target list.
    assert profiler.clear_targets() is profiler
    assert len(profiler.targets) == 0

    # Targets can be added again after clearing.
    profiler.add_function("os", None, "getcwd")
    assert len(profiler.targets) == 1
def test_defaults_and_clear():
    """Default targets load on construction and can be fully replaced."""
    profiler = HostProfiler(use_defaults=True)
    assert len(profiler.targets) > 10, "Should have many default targets"

    # Wipe the defaults and register two custom standalone functions.
    profiler.clear_targets()
    profiler.add_standalone_function("os", "getcwd")
    profiler.add_standalone_function("os.path", "join")

    listed = profiler.list_targets()
    assert len(profiler.targets) == 2
    assert "os.getcwd" in listed
    assert "os.path.join" in listed
def test_profiling_cycle_and_report_validation():
    """Test complete profiling cycle and validate report format.

    This is the main unit test that validates:
    1. Profiler starts and stops correctly
    2. Report contains Timer unit header
    3. Report contains column headers (Line, Hits, Time)
    4. Report contains profiled function name
    5. Report contains timing data
    6. Only profiled functions appear (not non-profiled ones)
    """
    try:
        import line_profiler  # noqa: F401
    except ImportError:
        pytest.skip("line_profiler not installed")
    # delete=False: the path must outlive this `with` so the profiler can
    # write to it; it is removed manually in the finally block below.
    with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f:
        output_path = f.name
    try:
        # Use clear_targets to ensure only our test functions are profiled
        profiler = HostProfiler(output_path=output_path, use_defaults=True)
        profiler.clear_targets().add_standalone_function(
            __name__, "_sample_function_to_profile"
        ).add_function(__name__, "_SampleClassToProfile", "instance_method")
        assert len(profiler.targets) == 2
        # Start profiling
        assert profiler.start() is True
        assert profiler.enabled is True
        # Execute profiled functions repeatedly so the report has real hits
        sample_obj = _SampleClassToProfile()
        for _ in range(50):
            _sample_function_to_profile(10)
            sample_obj.instance_method(5)
        # Execute non-profiled function - should NOT appear in output
        _another_function(3, 4)
        # Stop and save
        assert profiler.stop() is True
        assert profiler.enabled is False
        # Validate output file exists
        assert os.path.exists(output_path)
        with open(output_path) as f:
            content = f.read()
        # --- Report Format Validation ---
        # 1. Timer unit header
        assert "Timer unit:" in content, "Report missing 'Timer unit:' header"
        # 2. Column headers
        assert "Line" in content, "Report missing 'Line' column header"
        assert "Hits" in content, "Report missing 'Hits' column header"
        assert "Time" in content, "Report missing 'Time' column header"
        # 3. Profiled functions appear
        assert "_sample_function_to_profile" in content, "Profiled function not in output"
        assert "instance_method" in content, "Profiled method not in output"
        # 4. Non-profiled function should NOT appear
        assert "_another_function" not in content, "Non-profiled function should NOT be in output"
        # 5. Should have timing data (lines with numbers)
        lines = content.split("\n")
        data_lines = [
            line for line in lines if line.strip() and any(char.isdigit() for char in line)
        ]
        assert len(data_lines) > 5, "Report should have multiple lines with timing data"
        print("\n=== Report Validation Passed ===")
        print(f"Output size: {len(content)} bytes")
        print(f"Data lines: {len(data_lines)}")
    finally:
        # Always remove the temporary report, even when an assertion failed.
        if os.path.exists(output_path):
            os.unlink(output_path)
# Core PyExecutor methods that are always executed during inference.
# These are stable, fundamental methods unlikely to be renamed; the E2E test
# passes them via TLLM_LINE_PROFILER_FUNCTIONS (module.Class.method format).
E2E_PROFILE_TARGETS = [
    "tensorrt_llm._torch.pyexecutor.py_executor.PyExecutor._forward_step",
    "tensorrt_llm._torch.pyexecutor.py_executor.PyExecutor._schedule",
    "tensorrt_llm._torch.pyexecutor.sampler.TorchSampler.sample_async",
]
@pytest.fixture
def tinyllama_path():
    """Resolve the TinyLlama checkpoint directory, skipping when absent."""
    checkpoint = llm_models_root() / "llama-models-v2" / "TinyLlama-1.1B-Chat-v1.0"
    if not checkpoint.exists():
        pytest.skip(f"TinyLlama model not found at {checkpoint}")
    return str(checkpoint)
def test_e2e_profiler_with_model(tinyllama_path, mocker):
    """E2E test: verify profiler works with actual model inference.

    Clears default profile targets and adds only specific targets,
    then verifies those targets appear in the report with non-zero timing.

    Args:
        tinyllama_path: Fixture providing the TinyLlama checkpoint path.
        mocker: pytest-mock fixture used to patch defaults and MPI env.
    """
    from tensorrt_llm import LLM
    from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
    from tensorrt_llm.tools.profiler.host_profile_tools import host_profiler

    with tempfile.TemporaryDirectory() as tmpdir:
        output_path = os.path.join(tmpdir, "profiler_output.txt")
        # Clear default targets and use only our specific targets
        mocker.patch.object(host_profiler, "DEFAULT_PROFILE_TARGETS", [])
        # Patch MPI to propagate env vars to workers
        _patch_mpi_pool_session_for_env(
            mocker,
            {
                "TLLM_LINE_PROFILER_PATH": output_path,
                "TLLM_LINE_PROFILER_FUNCTIONS": ", ".join(E2E_PROFILE_TARGETS),
            },
        )
        with LLM(
            model=tinyllama_path,
            kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.3),
        ) as llm:
            # Generate enough tokens to ensure profiled methods are executed
            prompts = ["Hello, how are you?", "What is AI?"]
            outputs = llm.generate(prompts, SamplingParams(max_tokens=16, end_id=-1))
            assert len(outputs) == len(prompts)
        # Validate output file was created
        assert os.path.exists(output_path), f"Profiler output not created at {output_path}"
        with open(output_path) as f:
            content = f.read()
        # Validate report format
        assert "Timer unit:" in content, "Missing Timer unit header"
        # Verify all specified targets were profiled with actual timing data
        # Format: "Total time: X s" then "File: ..." then "Function: ClassName.method_name at line X"
        expected_methods = ["_forward_step", "_schedule", "sample_async"]
        for method in expected_methods:
            # Find block: "Total time: X s" followed by "File: ..." then "Function: ...method_name..."
            pattern = (
                rf"Total time:\s*([\d.e+-]+)\s*s\nFile:.*?\nFunction:.*\.{method}\s+at line \d+"
            )
            match = re.search(pattern, content)
            assert match, f"Method '{method}' not found in profiler output"
            total_time = float(match.group(1))
            assert total_time > 0, f"Method '{method}' has zero total time - not actually profiled"
        print("\n=== E2E Passed ===")
        print(f"Output: {len(content)} bytes")
        print(f"Verified methods with timing data: {expected_methods}")