Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-18 16:55:08 +08:00)
[TRTLLM-10851][feat] Add line_profiler tool for host overhead analysis. (#11232)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
This commit is contained in:
parent b003355050
commit ed404f9298
@@ -40,3 +40,4 @@ aiperf==0.3.0
 nanobind>=2.9.0
 nixl==0.8.0
 hf-transfer==0.1.9
+line_profiler
@@ -36,6 +36,8 @@ from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import CpType
 from tensorrt_llm.runtime.generation import CUASSERT
 from tensorrt_llm.tools.layer_wise_benchmarks import get_calibrator
+from tensorrt_llm.tools.profiler.host_profile_tools.host_profiler import \
+    host_profiler_context

 from ..distributed import Distributed
 from ..expert_statistic import ExpertStatistic
@@ -560,8 +562,12 @@ class PyExecutor:

     def _event_loop_wrapper(self):
         try:
-            with customized_gc_thresholds(
-                    self.garbage_collection_gen0_threshold):
+            # Skip line profiler during warmup/memory estimation phase to avoid
+            # saving incomplete results that would be overwritten anyway
+            enable_profiler = bool(os.environ.get(
+                "TLLM_LINE_PROFILER_PATH")) and not self.is_warmup
+            with host_profiler_context(enable=enable_profiler), \
+                    customized_gc_thresholds(self.garbage_collection_gen0_threshold):
                 self.event_loop()
         except Exception as e:
             logger.error(f"Error in event loop: {e}")
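Note: because the event loop only turns the profiler on when TLLM_LINE_PROFILER_PATH is set (and skips the warmup phase), the tool can be enabled without touching application code. The sketch below is illustrative and not part of this commit; the model path and output file are placeholders, and in multi-process runs the variables must also reach the MPI worker processes (the E2E test further down does this by patching MpiPoolSession).

# Sketch: enable host-side line profiling for a generation run.
# The env vars are read by HostProfiler inside the executor worker;
# the model directory and output paths are placeholders.
import os

os.environ["TLLM_LINE_PROFILER_PATH"] = "/tmp/host_profile.txt"
# Optional: restrict profiling to specific targets
# ("module.Class.method" for methods, "module.path::function" for functions).
os.environ["TLLM_LINE_PROFILER_FUNCTIONS"] = (
    "tensorrt_llm._torch.pyexecutor.py_executor.PyExecutor._forward_step,"
    "tensorrt_llm._torch.pyexecutor.sampler.TorchSampler.sample_async"
)

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import SamplingParams

with LLM(model="/path/to/TinyLlama-1.1B-Chat-v1.0") as llm:
    llm.generate(["Hello, how are you?"], SamplingParams(max_tokens=16))
# On shutdown the worker writes line-by-line timings to /tmp/host_profile.txt.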
tensorrt_llm/tools/profiler/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Profiler tools for TensorRT-LLM."""
tensorrt_llm/tools/profiler/host_profile_tools/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
"""Host-side profiling tools for CPU overhead analysis."""

from .host_profiler import (
    HostProfiler,
    get_global_profiler,
    host_profiler_context,
    set_global_profiler,
)

__all__ = [
    "HostProfiler",
    "get_global_profiler",
    "host_profiler_context",
    "set_global_profiler",
]
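Note: for programmatic use outside the executor, the package re-exports the profiler API shown above. The sketch below is illustrative and not part of the commit; it profiles trivial stdlib functions purely so it can run anywhere line_profiler is installed.

# Sketch: profile functions in-process and write a line-by-line report.
# The os.path targets are stand-ins; in practice you would point at executor methods.
import os.path

from tensorrt_llm.tools.profiler.host_profile_tools import HostProfiler

profiler = HostProfiler(output_path="./host_profile.txt", use_defaults=False)
profiler.add_standalone_function("os.path", "join").add_standalone_function("os.path", "splitext")

with profiler.profile():
    for i in range(1000):
        os.path.splitext(os.path.join("/tmp", f"file_{i}.txt"))

print(profiler.list_targets())  # ['os.path.join', 'os.path.splitext']
# The report is written to ./host_profile.txt when the context exits.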
tensorrt_llm/tools/profiler/host_profile_tools/host_profiler.py (new file, 763 lines)
@@ -0,0 +1,763 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Host-side profiler for analyzing CPU overhead in the PyExecutor.

This module provides a HostProfiler class that wraps line_profiler to measure
line-by-line execution time of critical functions in the executor worker thread.

Usage:
    Set environment variable TLLM_LINE_PROFILER_PATH to enable:
        TLLM_LINE_PROFILER_PATH=./lp_results.txt pytest ...

    Or use programmatically:
        profiler = HostProfiler(output_path="./results.txt")
        profiler.start()
        # ... run code ...
        profiler.stop()
"""

import importlib
import os
import threading
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

from tensorrt_llm.logger import logger

# Environment variable to enable line_profiler output path.
LINE_PROFILER_PATH_ENV_VAR = "TLLM_LINE_PROFILER_PATH"

# Environment variable to specify additional functions to profile (comma-separated).
# Format: "module.Class.method,module.Class.method2,..."
LINE_PROFILER_FUNCTIONS_ENV_VAR = "TLLM_LINE_PROFILER_FUNCTIONS"


@dataclass
class ProfileTarget:
    """Represents a function to be profiled.

    Supports both class methods and standalone module-level functions.

    Examples:
        Class method:
            ProfileTarget("module.path", "ClassName", "method_name")
            -> resolves to module.path.ClassName.method_name

        Standalone function:
            ProfileTarget("module.path", None, "function_name")
            -> resolves to module.path.function_name
    """

    module_path: str
    class_name: Optional[str]  # None for standalone functions
    method_name: str

    @property
    def full_path(self) -> str:
        if self.class_name is None:
            return f"{self.module_path}.{self.method_name}"
        return f"{self.module_path}.{self.class_name}.{self.method_name}"

    @property
    def is_standalone(self) -> bool:
        """Returns True if this is a standalone function (not a class method)."""
        return self.class_name is None

    def resolve(self) -> Optional[Callable]:
        """Resolve the target to an actual function object.

        Returns:
            The unwrapped method/function (inner function), or None if resolution fails.

        Note:
            We MUST unwrap decorated functions (e.g., @torch.inference_mode,
            @nvtx_range) because line_profiler traces by __code__ identity.
            When a function is wrapped by decorators like @torch.inference_mode(),
            the wrapper's __code__ is different from the actual function's __code__.
            The wrapper only contains a few lines that call the inner function,
            so line_profiler would only see those wrapper lines, not the actual
            function body we want to profile.

            By unwrapping, we get the actual function's __code__ which allows
            line_profiler to trace the real function lines.
        """
        try:
            module = importlib.import_module(self.module_path)

            if self.is_standalone:
                # Standalone module-level function
                func = getattr(module, self.method_name)
            else:
                # Class method
                cls = getattr(module, self.class_name)
                func = getattr(cls, self.method_name)

            # Unwrap decorated functions to get the actual inner function.
            # This is necessary for @torch.inference_mode(), @nvtx_range(), etc.
            # Without unwrapping, line_profiler would only trace the wrapper's
            # __code__, not the actual function body.
            while hasattr(func, "__wrapped__"):
                func = func.__wrapped__

            return func
        except (ImportError, AttributeError) as e:
            logger.warning(f"Failed to resolve profile target {self.full_path}: {e}")
            return None


# Default functions to profile for host overhead analysis
# Hierarchical config: {module_path: {class_name: [method_names]}}
# Use None as class_name for standalone module-level functions
#
# Wildcard support:
# - ["*"] as method list: Profile all methods of the class
# - {None: ["*"]}: Profile all standalone functions in the module
# - {"*": ["*"]}: Profile all classes and all their methods in the module
_PYEXEC = "tensorrt_llm._torch.pyexecutor"
_DEFAULT_PROFILE_CONFIG: Dict[str, Dict[Optional[str], List[str]]] = {
    f"{_PYEXEC}.py_executor": {
        "PyExecutor": [
            "_prepare_and_schedule_batch",
            "_schedule",
            "_forward_step",
            "_sample_async",
            "_update_requests",
            "_update_request_states",
            "_fetch_and_activate_new_requests",
            "_handle_responses",
            "_handle_canceled_requests",
            "_enqueue_responses",
        ],
    },
    f"{_PYEXEC}.sampler": {
        "TorchSampler": [
            "sample_async",
            "update_requests",
            "_process_requests",
            "_write_finish_reasons",
            "_prepare_beam_search",
            "_select_generated_logits",
            "_sample_batched_by_strategy",
        ],
        # Standalone module-level functions (use None as class_name)
        None: [
            "_group_requests_by_strategy_key",
        ],
    },
    f"{_PYEXEC}.resource_manager": {
        "ResourceManager": ["prepare_resources", "update_resources", "free_resources"],
        "KVCacheManager": ["prepare_resources", "update_resources"],
    },
    f"{_PYEXEC}.scheduler": {
        "RequestScheduler": ["schedule_request"],
    },
    f"{_PYEXEC}.executor_request_queue": {
        "ExecutorRequestQueue": [
            "_fetch_new_requests_attention_tp",
            "_fetch_new_requests_attention_dp",
            "_fetch_and_process_requests",
            "_merge_requests",
            "fetch_new_requests",
        ],
    },
}


def _get_all_methods_from_class(cls: type) -> List[str]:
    """Get all user-defined methods from a class (excluding inherited, dunder, nested classes, and properties).

    Only includes items that have actual executable code (__code__ attribute):
    - Regular instance methods (def foo(self): ...)
    - Static methods (@staticmethod)
    - Class methods (@classmethod)

    Excludes:
    - Nested classes (e.g., dataclasses like Args, Store)
    - Properties (usually trivial getters, complex to profile)
    - Constants and type aliases
    - Dunder methods (__init__, __repr__, etc.)

    Args:
        cls: The class to introspect.

    Returns:
        List of method names defined directly on the class.
    """
    import inspect

    methods = []
    for name, member in cls.__dict__.items():
        # Skip dunder methods
        if name.startswith("__") and name.endswith("__"):
            continue

        # Skip nested classes
        if inspect.isclass(member):
            continue

        # Skip properties (usually trivial, complex to handle with line_profiler)
        if isinstance(member, property):
            continue

        # Extract the underlying function from descriptors
        func = None
        if isinstance(member, staticmethod):
            func = member.__func__
        elif isinstance(member, classmethod):
            func = member.__func__
        elif inspect.isfunction(member):
            func = member

        # Only include if we have a valid function with executable code
        if func is not None and hasattr(func, "__code__"):
            methods.append(name)

    return methods


def _get_all_functions_from_module(module_path: str) -> List[str]:
    """Get all user-defined functions from a module (excluding imported ones).

    Args:
        module_path: The module path to introspect.

    Returns:
        List of function names defined directly in the module.
    """
    import inspect

    try:
        module = importlib.import_module(module_path)
    except ImportError as e:
        logger.warning(f"Failed to import module {module_path} for introspection: {e}")
        return []

    functions = []
    for name, member in inspect.getmembers(module, inspect.isfunction):
        # Only include functions defined in this module (not imported)
        if member.__module__ == module_path:
            # Skip private functions starting with underscore if desired
            # For profiling, we likely want to include them
            functions.append(name)
    return functions


def _get_all_classes_from_module(module_path: str) -> List[str]:
    """Get all user-defined classes from a module (excluding imported ones).

    Args:
        module_path: The module path to introspect.

    Returns:
        List of class names defined directly in the module.
    """
    import inspect

    try:
        module = importlib.import_module(module_path)
    except ImportError as e:
        logger.warning(f"Failed to import module {module_path} for introspection: {e}")
        return []

    classes = []
    for name, member in inspect.getmembers(module, inspect.isclass):
        # Only include classes defined in this module (not imported)
        if member.__module__ == module_path:
            classes.append(name)
    return classes


def _expand_profile_config(
    config: Dict[str, Dict[Optional[str], List[str]]],
) -> List[ProfileTarget]:
    """Expand hierarchical config into a flat list of ProfileTarget objects.

    Args:
        config: Hierarchical config mapping module_path -> class_name -> method_names.
            Use None as class_name for standalone module-level functions.

            Special markers:
            - "*" in method list: Profile all methods of the class
            - "*" as class_name key with ["*"]: Profile all classes and their methods
            - None as class_name key with ["*"]: Profile all standalone functions

    Returns:
        List of ProfileTarget objects.

    Examples:
        # Profile all methods of TorchSampler class:
        {"module.sampler": {"TorchSampler": ["*"]}}

        # Profile all standalone functions in a module:
        {"module.sampler": {None: ["*"]}}

        # Profile all classes and all their methods in a module:
        {"module.sampler": {"*": ["*"]}}

        # Mix explicit and wildcard:
        {"module.sampler": {"TorchSampler": ["*"], "OtherClass": ["specific_method"]}}
    """
    targets = []
    for module_path, classes in config.items():
        for class_name, methods in classes.items():
            # Handle wildcard class_name: "*" means all classes in the module
            if class_name == "*":
                if methods == ["*"]:
                    # Profile all methods of all classes in the module
                    all_classes = _get_all_classes_from_module(module_path)
                    for cls_name in all_classes:
                        try:
                            module = importlib.import_module(module_path)
                            cls = getattr(module, cls_name)
                            all_methods = _get_all_methods_from_class(cls)
                            for method_name in all_methods:
                                targets.append(ProfileTarget(module_path, cls_name, method_name))
                        except (ImportError, AttributeError) as e:
                            logger.warning(f"Failed to introspect {module_path}.{cls_name}: {e}")
                continue

            # Handle wildcard methods: ["*"] means all methods of the class/module
            if methods == ["*"]:
                if class_name is None:
                    # All standalone functions in the module
                    all_funcs = _get_all_functions_from_module(module_path)
                    for func_name in all_funcs:
                        targets.append(ProfileTarget(module_path, None, func_name))
                else:
                    # All methods of a specific class
                    try:
                        module = importlib.import_module(module_path)
                        cls = getattr(module, class_name)
                        all_methods = _get_all_methods_from_class(cls)
                        for method_name in all_methods:
                            targets.append(ProfileTarget(module_path, class_name, method_name))
                    except (ImportError, AttributeError) as e:
                        logger.warning(f"Failed to introspect {module_path}.{class_name}: {e}")
                continue

            # Normal case: explicit method list
            for method_name in methods:
                targets.append(ProfileTarget(module_path, class_name, method_name))
    return targets


DEFAULT_PROFILE_TARGETS: List[ProfileTarget] = _expand_profile_config(_DEFAULT_PROFILE_CONFIG)


class HostProfiler:
    """Host-side profiler for measuring CPU overhead in the executor.

    This class wraps line_profiler to provide line-by-line timing analysis
    of critical functions in the PyExecutor worker thread.

    Attributes:
        output_path: Path to save profiling results.
        targets: List of ProfileTarget objects specifying functions to profile.
        enabled: Whether profiling is currently active.

    Example:
        >>> profiler = HostProfiler(output_path="./results.txt")
        >>> profiler.add_target(
        ...     ProfileTarget(
        ...         module_path="my_module",
        ...         class_name="MyClass",
        ...         method_name="my_method",
        ...     )
        ... )
        >>> profiler.start()
        >>> # ... run code ...
        >>> profiler.stop()
    """

    def __init__(
        self,
        output_path: Optional[str] = None,
        targets: Optional[List[ProfileTarget]] = None,
        use_defaults: bool = True,
    ):
        """Initialize the host profiler.

        Args:
            output_path: Path to save results. If None, uses env var TLLM_LINE_PROFILER_PATH.
            targets: List of ProfileTarget objects. If None and use_defaults=True,
                uses DEFAULT_PROFILE_TARGETS.
            use_defaults: Whether to include default profile targets.
        """
        self.output_path = output_path or os.environ.get(LINE_PROFILER_PATH_ENV_VAR)
        self.targets: List[ProfileTarget] = []
        self._line_profiler = None
        self._enabled = False

        # Add default targets if requested
        if use_defaults:
            self.targets.extend(DEFAULT_PROFILE_TARGETS)

        # Add custom targets
        if targets:
            self.targets.extend(targets)

        # Parse additional targets from environment variable
        self._parse_env_targets()

    def _parse_env_targets(self) -> None:
        """Parse additional profile targets from environment variable.

        Supported formats:
        - module.Class.method -> class method
        - module::function -> standalone function (uses :: as delimiter)
        """
        env_funcs = os.environ.get(LINE_PROFILER_FUNCTIONS_ENV_VAR, "")
        if not env_funcs:
            return

        for func_path in env_funcs.split(","):
            func_path = func_path.strip()
            if not func_path:
                continue

            # Check for standalone function format: "module::function"
            if "::" in func_path:
                parts = func_path.rsplit("::", 1)
                if len(parts) != 2:
                    logger.warning(
                        f"Invalid standalone function path '{func_path}'. "
                        "Expected format: module.path::function_name"
                    )
                    continue
                module_path, method_name = parts
                self.targets.append(
                    ProfileTarget(
                        module_path=module_path,
                        class_name=None,  # Standalone function
                        method_name=method_name,
                    )
                )
            else:
                # Class method format: "module.Class.method"
                parts = func_path.rsplit(".", 2)
                if len(parts) < 3:
                    logger.warning(
                        f"Invalid function path '{func_path}'. Expected format: "
                        "module.Class.method (class method) or module::function (standalone)"
                    )
                    continue

                # Handle nested module paths
                method_name = parts[-1]
                class_name = parts[-2]
                module_path = ".".join(parts[:-2])

                self.targets.append(
                    ProfileTarget(
                        module_path=module_path,
                        class_name=class_name,
                        method_name=method_name,
                    )
                )

    def add_target(self, target: ProfileTarget) -> "HostProfiler":
        """Add a profile target.

        Args:
            target: The ProfileTarget to add.

        Returns:
            Self for chaining.
        """
        self.targets.append(target)
        return self

    def add_function(
        self,
        module_path: str,
        class_name: Optional[str],
        method_name: str,
    ) -> "HostProfiler":
        """Add a function to profile by specifying its path.

        Args:
            module_path: The module path (e.g., "tensorrt_llm._torch.pyexecutor.sampler")
            class_name: The class name (e.g., "TorchSampler"), or None for standalone functions
            method_name: The method/function name (e.g., "_process_requests")

        Returns:
            Self for chaining.

        Examples:
            # Add a class method
            profiler.add_function("my.module", "MyClass", "my_method")

            # Add a standalone function
            profiler.add_function("my.module", None, "my_function")
        """
        return self.add_target(
            ProfileTarget(
                module_path=module_path,
                class_name=class_name,
                method_name=method_name,
            )
        )

    def add_standalone_function(
        self,
        module_path: str,
        function_name: str,
    ) -> "HostProfiler":
        """Add a standalone module-level function to profile.

        This is a convenience method for adding standalone functions
        (not class methods).

        Args:
            module_path: The module path (e.g., "tensorrt_llm._torch.pyexecutor.sampler")
            function_name: The function name (e.g., "_group_requests_by_strategy_key")

        Returns:
            Self for chaining.
        """
        return self.add_target(
            ProfileTarget(
                module_path=module_path,
                class_name=None,
                method_name=function_name,
            )
        )

    def clear_targets(self) -> "HostProfiler":
        """Clear all profile targets (including defaults).

        This is useful when you want to start with a clean slate and add
        only specific targets, or when you want to replace all default
        targets with a custom set.

        Returns:
            Self for chaining.

        Example:
            # Clear defaults and add only specific targets
            profiler = HostProfiler(use_defaults=True)
            profiler.clear_targets().add_function("my.module", "MyClass", "method")

            # Or start fresh without defaults
            profiler = HostProfiler(use_defaults=False)
        """
        self.targets.clear()
        return self

    @property
    def is_available(self) -> bool:
        """Check if line_profiler is available."""
        return True

    @property
    def should_profile(self) -> bool:
        """Check if profiling should be enabled (output path is set)."""
        return self.output_path is not None

    @property
    def enabled(self) -> bool:
        """Check if profiling is currently active."""
        return self._enabled

    def start(self) -> bool:
        """Start profiling.

        Returns:
            True if profiling started successfully, False otherwise.
        """
        if not self.should_profile:
            logger.info("Line profiler not enabled (no output path specified)")
            return False

        if not self.is_available:
            logger.warning("line_profiler not installed. Install with: pip install line_profiler")
            return False

        if self._enabled:
            logger.warning("Line profiler already started")
            return True

        try:
            from line_profiler import LineProfiler

            self._line_profiler = LineProfiler()

            # Add all target functions
            resolved_count = 0
            for target in self.targets:
                func = target.resolve()
                if func is not None:
                    logger.info(
                        f"line profiler func code ID: {id(func.__code__)}, target: {target.full_path}"
                    )
                    self._line_profiler.add_function(func)
                    resolved_count += 1

            if resolved_count == 0:
                logger.warning("No profile targets could be resolved")
                self._line_profiler = None
                return False

            self._line_profiler.enable()
            self._enabled = True
            self._profiler_thread_id = threading.current_thread().ident
            logger.info(
                f"Line profiler enabled with {resolved_count}/{len(self.targets)} targets. "
                f"Thread ID: {self._profiler_thread_id}, Thread name: {threading.current_thread().name}. "
                f"Results will be saved to: {self.output_path}"
            )
            dump_profiler_functions()
            return True

        except Exception as e:
            logger.error(f"Failed to start line profiler: {e}")
            self._line_profiler = None
            return False

    def stop(self) -> bool:
        """Stop profiling and save results.

        Returns:
            True if results were saved successfully, False otherwise.
        """
        if not self._enabled or self._line_profiler is None:
            return False

        try:
            self._line_profiler.disable()
            self._enabled = False

            # Save results
            with open(self.output_path, "w") as f:
                self._line_profiler.print_stats(stream=f)

            logger.info(f"Line profiler results saved to: {self.output_path}")
            return True

        except Exception as e:
            logger.error(f"Failed to save line profiler results: {e}")
            return False

        finally:
            self._line_profiler = None

    @contextmanager
    def profile(self):
        """Context manager for profiling.

        Usage:
            with profiler.profile():
                # ... code to profile ...
        """
        started = self.start()
        try:
            yield self
        finally:
            if started:
                self.stop()

    def get_stats_string(self) -> Optional[str]:
        """Get profiling stats as a string (without saving to file).

        Returns:
            Stats string if profiling is active, None otherwise.
        """
        if self._line_profiler is None:
            return None

        import io

        stream = io.StringIO()
        self._line_profiler.print_stats(stream=stream)
        return stream.getvalue()

    def list_targets(self) -> List[str]:
        """List all configured profile targets.

        Returns:
            List of target paths.
        """
        return [t.full_path for t in self.targets]


# Global profiler instance for use in worker thread
_global_profiler: Optional[HostProfiler] = None


def get_global_profiler() -> Optional[HostProfiler]:
    """Get the global profiler instance."""
    return _global_profiler


def dump_profiler_functions() -> None:
    """Print all functions registered with the line profiler for debugging.

    Only dumps on rank 0 to avoid interleaved output from multiple ranks.
    """
    # Import here to avoid circular imports and handle cases where MPI is not initialized
    try:
        from tensorrt_llm._utils import mpi_rank

        if mpi_rank() != 0:
            return
    except Exception:
        pass  # If MPI is not available, proceed (single-rank case)

    profiler = get_global_profiler()
    if profiler is None or profiler._line_profiler is None:
        logger.info("No line profiler active")
        return

    lp = profiler._line_profiler
    logger.info(f"=== Line Profiler State: {len(lp.functions)} functions registered ===")
    for func in lp.functions:
        logger.info(f"  {func.__module__}.{func.__qualname__}, code id: {id(func.__code__)}")
    logger.info("=== End Line Profiler State ===")


def set_global_profiler(profiler: Optional[HostProfiler]) -> None:
    """Set the global profiler instance."""
    global _global_profiler
    _global_profiler = profiler


@contextmanager
def host_profiler_context(enable: bool = True, output_path: Optional[str] = None):
    """Context manager for host profiling in the worker thread.

    This is the main entry point for profiling in PyExecutor._event_loop_wrapper.

    Args:
        output_path: Path to save results. If None, uses env var.

    Usage:
        with host_profiler_context():
            # ... event loop code ...
    """
    if not enable:
        yield None
        return

    profiler = HostProfiler(output_path=output_path)
    set_global_profiler(profiler)

    started = profiler.start()
    try:
        yield profiler
    finally:
        if started:
            profiler.stop()
        set_global_profiler(None)
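Note: the wildcard expansion and the two TLLM_LINE_PROFILER_FUNCTIONS spellings described above can be checked offline before wiring them into a run. The sketch below is not part of the commit; it uses the stdlib module json.decoder purely as a resolvable stand-in and calls the module-internal _expand_profile_config helper.

# Sketch: expand a wildcard config and inspect the resulting targets.
# json.decoder / JSONDecoder are stdlib stand-ins, not defaults shipped by the tool.
import os

from tensorrt_llm.tools.profiler.host_profile_tools.host_profiler import (
    HostProfiler,
    _expand_profile_config,
)

targets = _expand_profile_config({
    "json.decoder": {
        "JSONDecoder": ["*"],  # every method of the class
        None: ["*"],           # every standalone function in the module
    },
})
print([t.full_path for t in targets])

# Equivalent env-var spelling for two explicit targets
# (dots for class methods, "::" for standalone functions):
os.environ["TLLM_LINE_PROFILER_FUNCTIONS"] = (
    "json.decoder.JSONDecoder.decode,json.decoder::py_scanstring"
)
profiler = HostProfiler(use_defaults=False)
print(profiler.list_targets())  # ['json.decoder.JSONDecoder.decode', 'json.decoder.py_scanstring']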
tests/unittest/tools/test_host_profiler.py (new file, 278 lines)
@@ -0,0 +1,278 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the HostProfiler module.

Tests cover:
1. Core API functionality (add_*, clear_targets, chaining)
2. Line profiler integration with report validation
3. E2E test with actual model inference
"""

import os
import re
import sys
import tempfile

import pytest

# Add path for test utilities
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from utils.llm_data import llm_models_root

from tensorrt_llm.tools.profiler.host_profile_tools.host_profiler import HostProfiler, ProfileTarget


def _sample_function_to_profile(n: int) -> int:
    """A simple function with multiple lines to profile."""
    total = 0
    for i in range(n):
        total += i
    return total


def _another_function(x: int, y: int) -> int:
    """Another function for testing."""
    result = x + y
    result *= 2
    return result


class _SampleClassToProfile:
    """A sample class with methods to profile."""

    def instance_method(self, n: int) -> int:
        """Instance method to profile."""
        result = 0
        for i in range(n):
            result += i * 2
        return result


def _patch_mpi_pool_session_for_env(mocker, env_vars: dict):
    """Patch MpiPoolSession to propagate env vars to MPI workers."""
    from mpi4py.futures import MPIPoolExecutor

    from tensorrt_llm.llmapi.mpi_session import MpiPoolSession

    def patched_start_mpi_pool(self):
        assert not self.mpi_pool, "MPI session already started"
        self.mpi_pool = MPIPoolExecutor(max_workers=self.n_workers, path=sys.path, env=env_vars)

    mocker.patch.object(MpiPoolSession, "_start_mpi_pool", patched_start_mpi_pool)


def test_add_and_clear_targets():
    """Test add_*, clear_targets(), and chaining work correctly."""
    profiler = HostProfiler(use_defaults=False)

    # Test add_* methods with chaining
    result = (
        profiler.add_function("m1", "C1", "f1")
        .add_standalone_function("m2", "f2")
        .add_target(ProfileTarget("m3", "C3", "f3"))
    )

    assert result is profiler
    assert len(profiler.targets) == 3
    assert profiler.list_targets() == ["m1.C1.f1", "m2.f2", "m3.C3.f3"]

    # Test clear_targets with chaining
    result = profiler.clear_targets()
    assert result is profiler
    assert len(profiler.targets) == 0

    # Test add after clear
    profiler.add_function("os", None, "getcwd")
    assert len(profiler.targets) == 1


def test_defaults_and_clear():
    """Test that defaults are loaded and can be cleared."""
    profiler = HostProfiler(use_defaults=True)
    initial_count = len(profiler.targets)
    assert initial_count > 10, "Should have many default targets"

    # Clear all and add custom
    profiler.clear_targets().add_standalone_function("os", "getcwd").add_standalone_function(
        "os.path", "join"
    )

    assert len(profiler.targets) == 2
    assert "os.getcwd" in profiler.list_targets()
    assert "os.path.join" in profiler.list_targets()


def test_profiling_cycle_and_report_validation():
    """Test complete profiling cycle and validate report format.

    This is the main unit test that validates:
    1. Profiler starts and stops correctly
    2. Report contains Timer unit header
    3. Report contains column headers (Line, Hits, Time)
    4. Report contains profiled function name
    5. Report contains timing data
    6. Only profiled functions appear (not non-profiled ones)
    """
    try:
        import line_profiler  # noqa: F401
    except ImportError:
        pytest.skip("line_profiler not installed")

    with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f:
        output_path = f.name

    try:
        # Use clear_targets to ensure only our test functions are profiled
        profiler = HostProfiler(output_path=output_path, use_defaults=True)
        profiler.clear_targets().add_standalone_function(
            __name__, "_sample_function_to_profile"
        ).add_function(__name__, "_SampleClassToProfile", "instance_method")

        assert len(profiler.targets) == 2

        # Start profiling
        assert profiler.start() is True
        assert profiler.enabled is True

        # Execute profiled functions
        sample_obj = _SampleClassToProfile()
        for _ in range(50):
            _sample_function_to_profile(10)
            sample_obj.instance_method(5)
            # Execute non-profiled function - should NOT appear in output
            _another_function(3, 4)

        # Stop and save
        assert profiler.stop() is True
        assert profiler.enabled is False

        # Validate output file exists
        assert os.path.exists(output_path)
        with open(output_path) as f:
            content = f.read()

        # --- Report Format Validation ---

        # 1. Timer unit header
        assert "Timer unit:" in content, "Report missing 'Timer unit:' header"

        # 2. Column headers
        assert "Line" in content, "Report missing 'Line' column header"
        assert "Hits" in content, "Report missing 'Hits' column header"
        assert "Time" in content, "Report missing 'Time' column header"

        # 3. Profiled functions appear
        assert "_sample_function_to_profile" in content, "Profiled function not in output"
        assert "instance_method" in content, "Profiled method not in output"

        # 4. Non-profiled function should NOT appear
        assert "_another_function" not in content, "Non-profiled function should NOT be in output"

        # 5. Should have timing data (lines with numbers)
        lines = content.split("\n")
        data_lines = [
            line for line in lines if line.strip() and any(char.isdigit() for char in line)
        ]
        assert len(data_lines) > 5, "Report should have multiple lines with timing data"

        print("\n=== Report Validation Passed ===")
        print(f"Output size: {len(content)} bytes")
        print(f"Data lines: {len(data_lines)}")

    finally:
        if os.path.exists(output_path):
            os.unlink(output_path)


# Core PyExecutor methods that are always executed during inference.
# These are stable, fundamental methods unlikely to be renamed.
E2E_PROFILE_TARGETS = [
    "tensorrt_llm._torch.pyexecutor.py_executor.PyExecutor._forward_step",
    "tensorrt_llm._torch.pyexecutor.py_executor.PyExecutor._schedule",
    "tensorrt_llm._torch.pyexecutor.sampler.TorchSampler.sample_async",
]


@pytest.fixture
def tinyllama_path():
    """Get TinyLlama model path."""
    model_path = llm_models_root() / "llama-models-v2" / "TinyLlama-1.1B-Chat-v1.0"
    if not model_path.exists():
        pytest.skip(f"TinyLlama model not found at {model_path}")
    return str(model_path)


def test_e2e_profiler_with_model(tinyllama_path, mocker):
    """E2E test: verify profiler works with actual model inference.

    Clears default profile targets and adds only specific targets,
    then verifies those targets appear in the report with non-zero timing.
    """
    from tensorrt_llm import LLM
    from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
    from tensorrt_llm.tools.profiler.host_profile_tools import host_profiler

    with tempfile.TemporaryDirectory() as tmpdir:
        output_path = os.path.join(tmpdir, "profiler_output.txt")

        # Clear default targets and use only our specific targets
        mocker.patch.object(host_profiler, "DEFAULT_PROFILE_TARGETS", [])

        # Patch MPI to propagate env vars to workers
        _patch_mpi_pool_session_for_env(
            mocker,
            {
                "TLLM_LINE_PROFILER_PATH": output_path,
                "TLLM_LINE_PROFILER_FUNCTIONS": ", ".join(E2E_PROFILE_TARGETS),
            },
        )

        with LLM(
            model=tinyllama_path,
            kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.3),
        ) as llm:
            # Generate enough tokens to ensure profiled methods are executed
            prompts = ["Hello, how are you?", "What is AI?"]
            outputs = llm.generate(prompts, SamplingParams(max_tokens=16, end_id=-1))
            assert len(outputs) == len(prompts)

        # Validate output file was created
        assert os.path.exists(output_path), f"Profiler output not created at {output_path}"

        with open(output_path) as f:
            content = f.read()

        # Validate report format
        assert "Timer unit:" in content, "Missing Timer unit header"

        # Verify all specified targets were profiled with actual timing data
        # Format: "Total time: X s" then "File: ..." then "Function: ClassName.method_name at line X"
        expected_methods = ["_forward_step", "_schedule", "sample_async"]

        for method in expected_methods:
            # Find block: "Total time: X s" followed by "File: ..." then "Function: ...method_name..."
            pattern = (
                rf"Total time:\s*([\d.e+-]+)\s*s\nFile:.*?\nFunction:.*\.{method}\s+at line \d+"
            )
            match = re.search(pattern, content)

            assert match, f"Method '{method}' not found in profiler output"

            total_time = float(match.group(1))
            assert total_time > 0, f"Method '{method}' has zero total time - not actually profiled"

        print("\n=== E2E Passed ===")
        print(f"Output: {len(content)} bytes")
        print(f"Verified methods with timing data: {expected_methods}")
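Note: besides writing a report file, the profiler can also be inspected in-process through get_stats_string() before stop() releases the underlying LineProfiler. The sketch below is not part of the commit; it assumes a POSIX /dev/null, an installed line_profiler, and uses a stdlib function as a stand-in target.

# Sketch: read the report back in-process instead of (or before) saving it.
import os.path

from tensorrt_llm.tools.profiler.host_profile_tools import HostProfiler

profiler = HostProfiler(output_path="/dev/null", use_defaults=False)
profiler.add_standalone_function("os.path", "basename")

if profiler.start():
    for i in range(100):
        os.path.basename(f"/tmp/run_{i}/log.txt")
    report = profiler.get_stats_string()  # roughly the text that stop() would write
    profiler.stop()
    print(report.splitlines()[0])  # e.g. "Timer unit: 1e-06 s"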