diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index f720e4767c..2d64022898 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -1,14 +1,20 @@ +import ast import contextlib import copy import inspect import itertools +import json +import os +import time from abc import ABC, abstractmethod from dataclasses import dataclass, field from functools import lru_cache +from pathlib import Path from typing import Any, Callable, Dict, List, Set, Tuple, Union import torch +import tensorrt_llm from tensorrt_llm.bindings.internal.runtime import delay_kernel from tensorrt_llm.logger import logger @@ -201,10 +207,26 @@ class TunableRunner(ABC): @contextlib.contextmanager -def autotune(tune_mode: bool = True): +def autotune(tune_mode: bool = True, cache_path: str = None, rank: int = 0): + # if cache_path is provided, use the rank-specific file + tune_required = tune_mode + if cache_path is not None: + # check if the rank-specific file exists + cache_path_no_ext = os.path.splitext(cache_path)[0] + cache_path_no_ext_rank = cache_path_no_ext + f".rank{rank}.json" + # if the rank-specific file exists, load it + file_exists = os.path.exists(cache_path_no_ext_rank) + # if the rank-specific file exists, do not enable tuning mode + tune_required = tune_required and not os.path.exists(cache_path) + if file_exists: + logger.info( + f"[Autotuner] Loading cache from {cache_path_no_ext_rank}") + AutoTuner.get().profiling_cache.load_cache(cache_path_no_ext_rank) + + # record the old tuning mode old_mode = AutoTuner.get().is_tuning_mode - AutoTuner.get().is_tuning_mode = tune_mode - autotune_enabled = tune_mode and not old_mode + AutoTuner.get().is_tuning_mode = tune_required + autotune_enabled = tune_required and not old_mode if autotune_enabled: logger.info("[Autotuner] Autotuning process starts ...") try: @@ -214,6 +236,11 @@ def autotune(tune_mode: bool = True): if autotune_enabled: logger.info("[Autotuner] Autotuning process ends") + # save cache + if cache_path is not None: + logger.info(f"[Autotuner] Saving cache to {cache_path_no_ext_rank}") + AutoTuner.get().profiling_cache.save_cache(cache_path_no_ext_rank) + @dataclass class AutoTunerStatistics: @@ -268,37 +295,40 @@ class AutoTunerStatistics: return stats_str -class AutoTuner: - """AutoTuner for optimizing TensorRT LLM operations. +class AutoTunerProfilingCache: + """AutoTunerCache for caching profiling results. - This class handles automatic performance tuning of tensor operations by profiling - different implementations and caching the best performing configurations. - - Args: - warmup (int): Number of warmup iterations before profiling (default: 3) - repeat (int): Number of profiling iterations for averaging (default: 10) - stream_delay_micro_secs (int): Delay on CUDA stream before the profiled kernel runs in microseconds (default: 1000) + The profiling cache can be serialized to disk for persistence across sessions: + - Use save_cache() to save the cache after tuning + - Use load_cache() to restore cached results before inference + - JSON format provides human-readable output and cross-platform compatibility """ - _instance = None - def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000): - self.repeat = repeat - self.warmup = warmup - self.stream_delay_micro_secs = stream_delay_micro_secs - self.profiling_cache = {} - self.registered_tuning_configs = {} - self.is_tuning_mode = False + def __init__(self): + self.cache = {} - # Add statistics tracking - self.stats = AutoTunerStatistics() + # Cache metadata for local storage and validation + self.lib_version = tensorrt_llm.__version__ + self.creation_timestamp = time.time() + # gpu_platform + self.device_name = torch.cuda.get_device_name() + self.device_capability = torch.cuda.get_device_capability() - self.profiling_debug = True + def __setitem__(self, cache_key: Tuple, value: Tuple) -> None: + self.cache[cache_key] = value - @classmethod - def get(cls): - if cls._instance is None: - cls._instance = AutoTuner() - return cls._instance + def __getitem__(self, cache_key: Tuple) -> Tuple: + return self.cache[cache_key] + + def __len__(self) -> int: + return len(self.cache) + + def clear(self) -> None: + self.cache.clear() + + def fallback_entry(self) -> Tuple: + # runner_id = 0, tactic = -1 + return 0, -1, float('inf') def search_cache( self, @@ -319,12 +349,195 @@ class AutoTuner: [is_cache_hit, runner_id, tactic, stored_profile] """ for r in runners: - if (cache_key := AutoTuner._get_cache_key( - custom_op, r, input_shapes, - tuning_config)) in self.profiling_cache: - return True, *self.profiling_cache[cache_key] + if (cache_key := self.get_cache_key(custom_op, r, input_shapes, + tuning_config)) in self.cache: + return True, *self.cache[cache_key] - return False, 0, -1, None + return False, *self.fallback_entry() + + def get_cache_key( + self, + custom_op: str, + runner: TunableRunner, + input_shapes: Tuple[torch.Size], + tuning_config: TuningConfig, + ) -> Tuple: + return ( + custom_op, + runner.__class__.__name__, + hash(runner), + AutoTuner.get()._find_nearest_profile( + input_shapes, + tuning_config.dynamic_tensor_specs, + tuning_config.constraint_specs, + tuning_config.tune_max_num_tokens, + ), + ) + + def get_specific_custom_op(self, custom_op: str) -> Dict[Tuple, Tuple]: + return {k: v for k, v in self.cache.items() if k[0] == custom_op} + + def save_cache(self, file_path: Union[str, Path]) -> None: + """Save the profiling cache to disk in JSON format. + + Args: + file_path: Path where to save the cache + + Raises: + IOError: If file cannot be written + + Note: + The cache is saved in JSON format which provides human-readable output. + Some type information may be lost for complex tactic objects. + """ + file_path = Path(file_path) + file_path.parent.mkdir(parents=True, exist_ok=True) + + try: + serializable_cache = self._serialize_cache_to_json() + with open(file_path, 'w') as f: + json.dump(serializable_cache, f, indent=2, default=str) + logger.info( + f"[AutoTuner] Successfully saved cache to {file_path} using JSON format" + ) + except Exception as e: + logger.error(f"[AutoTuner] Failed to save cache with JSON: {e}") + raise + + def load_cache(self, file_path: Union[str, Path]) -> None: + """Load the profiling cache from disk in JSON format. + + Args: + file_path: Path to the cache file + + Raises: + FileNotFoundError: If cache file doesn't exist + IOError: If file cannot be read + + Note: + Loading will replace the current cache contents. The cache is loaded + from JSON format. + """ + file_path = Path(file_path) + if not file_path.exists(): + raise FileNotFoundError(f"Cache file not found: {file_path}") + + try: + with open(file_path, 'r') as f: + serializable_cache = json.load(f) + self.cache = self._deserialize_cache_from_json(serializable_cache) + logger.info( + f"[AutoTuner] Successfully loaded cache from {file_path} using JSON format" + ) + except Exception as e: + logger.error(f"[AutoTuner] Failed to load cache with JSON: {e}") + raise + + def _serialize_cache_to_json(self) -> Dict[str, Any]: + """Convert the profiling cache to a JSON-serializable format. + + Returns: + Dictionary that can be serialized to JSON + + Note: + This method handles the conversion of complex objects to JSON-compatible + representations. Some type information may be lost in the conversion. + """ + serializable_cache = { + "metadata": { + "lib_version": self.lib_version, + "creation_timestamp": self.creation_timestamp, + "device_name": self.device_name, + "device_capability": self.device_capability, + }, + "cache_data": {}, + } + + for key, value in self.cache.items(): + # Convert tuple key to string for JSON compatibility + key_str = str(key) + + runner_id, tactic, min_time = value + + serializable_cache["cache_data"][key_str] = { + "runner_id": runner_id, + "tactic": tactic, + "min_time": min_time, + } + + return serializable_cache + + def _deserialize_cache_from_json( + self, serializable_cache: Dict[str, Any]) -> Dict[Tuple, Tuple]: + """Convert JSON-serialized cache back to the original format. + + Args: + serializable_cache: Dictionary loaded from JSON + + Returns: + Profiling cache in the original format + + Note: + This attempts to reconstruct the original data structures but may not + perfectly preserve all type information, especially for complex tactic objects. + """ + metadata = serializable_cache["metadata"] + self.lib_version = metadata["lib_version"] + self.creation_timestamp = metadata["creation_timestamp"] + self.device_name = metadata["device_name"] + self.device_capability = metadata["device_capability"] + + cache = {} + cache_data = serializable_cache["cache_data"] + + for key_str, value in cache_data.items(): + # Reconstruct the tuple key safely + try: + key = ast.literal_eval(key_str) # Safer than eval() + except (ValueError, SyntaxError): + logger.warning( + f"[AutoTuner] Could not reconstruct cache key: {key_str}") + continue + + runner_id = value["runner_id"] + tactic = value["tactic"] + min_time = value["min_time"] + + cache[key] = (runner_id, tactic, min_time) + + return cache + + +class AutoTuner: + """AutoTuner for optimizing TensorRT LLM operations. + + This class handles automatic performance tuning of tensor operations by profiling + different implementations and caching the best performing configurations. + + Args: + warmup (int): Number of warmup iterations before profiling (default: 3) + repeat (int): Number of profiling iterations for averaging (default: 10) + stream_delay_micro_secs (int): Delay on CUDA stream before the profiled kernel runs in microseconds (default: 1000) + """ + _instance = None + + def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000): + self.repeat = repeat + self.warmup = warmup + self.stream_delay_micro_secs = stream_delay_micro_secs + self.profiling_cache = AutoTunerProfilingCache() + self.is_tuning_mode = False + + # Add statistics tracking + self.stats = AutoTunerStatistics() + + self.profiling_debug = True + + @classmethod + def get(cls): + if cls._instance is None: + cls._instance = AutoTuner() + return cls._instance def choose_one( self, @@ -360,7 +573,7 @@ class AutoTuner: input_shapes = tuple(self._get_input_sizes(inputs)) # Early return if it's not tuning, use cache found one or fallback one if not self.is_tuning_mode: - is_cache_hit, best_runner_id, best_tactic, stored_profile = self.search_cache( + is_cache_hit, best_runner_id, best_tactic, min_time = self.profiling_cache.search_cache( custom_op, runners, input_shapes, tuning_config) best_runner = runners[best_runner_id] # TODO: check the stored runner and tactic can implement this shape here @@ -388,21 +601,21 @@ class AutoTuner: for p in profiles: tensors = self._prepare_input_tensors(p, inputs) - is_cache_hit, *_ = self.search_cache(custom_op, runners, - p.get_opt_shapes(), - tuning_config) + is_cache_hit, *_ = self.profiling_cache.search_cache( + custom_op, runners, p.get_opt_shapes(), tuning_config) if not is_cache_hit: # Initialize runner and tactic as None in case of no valid tactic or runners are found - best_runner_id, best_tactic, has_tuning_failure_occured = self._profile_runners( + best_runner_id, best_tactic, min_time, has_tuning_failure_occured = self._profile_runners( custom_op, runners, tensors, p, tuning_config, **kwargs) if best_runner_id is not None: # At least one valid (runner, tactic) pair is found - cache_key = AutoTuner._get_cache_key( + cache_key = self.profiling_cache.get_cache_key( custom_op, runners[best_runner_id], p.get_opt_shapes(), tuning_config) # inspect call stack self.profiling_cache[cache_key] = (best_runner_id, - best_tactic, p) + best_tactic, min_time) + self.stats.tuned_op_successful_configs[ custom_op] = self.stats.tuned_op_successful_configs.get( custom_op, 0) + 1 @@ -430,8 +643,8 @@ class AutoTuner: # Get the best runner and tactic from cache # If no valid tactic is found, the fallback runner and tactic will be used - _, runner_id, tactic, _ = self.search_cache(custom_op, runners, - input_shapes, tuning_config) + _, runner_id, tactic, _ = self.profiling_cache.search_cache( + custom_op, runners, input_shapes, tuning_config) return (runners[runner_id], tactic) @@ -479,9 +692,9 @@ class AutoTuner: if custom_op not in self.stats.failed_profiling_count: self.stats.failed_profiling_count[custom_op] = set() self.stats.failed_profiling_count[custom_op].add( - AutoTuner._get_cache_key(custom_op, runner, - profile.get_opt_shapes(), - tuning_config)) + self.profiling_cache.get_cache_key( + custom_op, runner, profile.get_opt_shapes(), + tuning_config)) # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics # or some runtime error occurs during profiling. @@ -491,7 +704,7 @@ class AutoTuner: min_time = time_measured best_runner_id, best_tactic = runner_id, tac - return best_runner_id, best_tactic, has_tuning_failure_occured + return best_runner_id, best_tactic, min_time, has_tuning_failure_occured def _get_input_sizes(self, inputs: List[torch.Tensor]) -> List[torch.Size]: @@ -677,20 +890,6 @@ class AutoTuner: return tuple(tuple(shape) for shape in base_profile) - @classmethod - def _get_cache_key( - cls, - custom_op: str, - runner: TunableRunner, - input_shapes: Tuple[torch.Size], - tuning_config: TuningConfig, - ) -> Tuple: - return (custom_op, runner.__class__.__name__, hash(runner), - cls._find_nearest_profile(input_shapes, - tuning_config.dynamic_tensor_specs, - tuning_config.constraint_specs, - tuning_config.tune_max_num_tokens)) - def _create_tensor_like(self, origin_tensor: torch.Tensor, dims: List[Dim]) -> torch.Tensor: """Create a new tensor matching the properties of the original tensor. @@ -746,7 +945,8 @@ class AutoTuner: logger.debug( f"[Autotuner] Cache contents: (custom_op, runner, hash(attributes), shape_profiles) -> (runner_id, tactic, shape_profile(ignored))" ) - for key, value in self.profiling_cache.items(): - runner_id, tactic, profile = value + for key, value in self.profiling_cache.cache.items(): + runner_id, tactic, min_time = value logger.debug( - f"[Autotuner] {key}: (runner_id={runner_id}, tactic={tactic})") + f"[Autotuner] {key}: (runner_id={runner_id}, tactic={tactic}, min_time={min_time})" + ) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index bcd95020bb..a0f650f6af 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -4,6 +4,7 @@ import functools import gc import inspect import math +import os import weakref from abc import ABC, abstractmethod from contextlib import contextmanager @@ -662,7 +663,10 @@ class PyTorchModelEngine(ModelEngine): torch.cuda.synchronize() if self.pytorch_backend_config.enable_autotuner: - with self.no_cuda_graph(), autotune(): + # handle multiple rank issue + cache_path = os.environ.get("TLLM_AUTOTUNER_CACHE_PATH", None) + with self.no_cuda_graph(), autotune(cache_path=cache_path, + rank=self.mapping.rank): result = get_autotune_warmup_request() with release_batch(result) as batch: if batch is None: diff --git a/tests/unittest/_torch/misc/test_autotuner.py b/tests/unittest/_torch/misc/test_autotuner.py index 5ed816df8d..fbdd6de74e 100644 --- a/tests/unittest/_torch/misc/test_autotuner.py +++ b/tests/unittest/_torch/misc/test_autotuner.py @@ -1,3 +1,5 @@ +import os +import tempfile from typing import Dict, List import torch @@ -262,13 +264,13 @@ def test_multiple_runners_different_attributes(): # Verify different cache keys are generated shapes = (x.shape, w.shape) - cache_key_0 = tuner._get_cache_key( + cache_key_0 = tuner.profiling_cache.get_cache_key( custom_op="test_multiple_runners", input_shapes=shapes, runner=runner_0, tuning_config=tuning_config, ) - cache_key_1 = tuner._get_cache_key( + cache_key_1 = tuner.profiling_cache.get_cache_key( custom_op="test_multiple_runners", input_shapes=shapes, runner=runner_1, @@ -297,16 +299,25 @@ def test_multiple_dynamic_shapes_cache(): # Do tuning with a sample input x = torch.randn(3, 64) - with autotune(): + temp_dir = tempfile.TemporaryDirectory() + with autotune(cache_path=os.path.join(temp_dir.name, + "test_multiple_dynamic_shapes.json")): tuner = AutoTuner.get() runner, tactic = tuner.choose_one("test_multiple_dynamic_shapes", runners, tuning_config, [x, w]) + cache_entries = tuner.profiling_cache.get_specific_custom_op( + "test_multiple_dynamic_shapes") + assert len(cache_entries) == 12, \ + f"Expected 12 cache entries for 3x4 shape combinations, got {len(cache_entries)}" # Verify cache size - should have 12 entries (3x4 combinations) - cache_entries = [ - k for k in tuner.profiling_cache.keys() - if k[0] == "test_multiple_dynamic_shapes" - ] + # We also test the cache serialization and deserialization here. + AutoTuner.get().profiling_cache.clear() + AutoTuner.get().profiling_cache.load_cache( + os.path.join(temp_dir.name, "test_multiple_dynamic_shapes.rank0.json")) + cache_entries = tuner.profiling_cache.get_specific_custom_op( + "test_multiple_dynamic_shapes") + assert len(cache_entries) == 12, \ f"Expected 12 cache entries for 3x4 shape combinations, got {len(cache_entries)}"