[TRTLLM-4500][feat] Add serialization/deserialization options for AutoTuner profiling cache (#7738)

To achieve deterministic AutoTuner results across runs, serialization and deserialization options are introduced so the profiling cache can be stored on disk in JSON format. Set the TLLM_AUTOTUNER_CACHE_PATH environment variable to the path where the cache file should be stored.
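For example, point the tuner at a writable location before engine warmup (a minimal sketch; the path is illustrative):

    # Each rank saves/loads its own <path-without-extension>.rank<rank>.json
    # derived from this value (see the autotune() changes below).
    import os
    os.environ["TLLM_AUTOTUNER_CACHE_PATH"] = "/tmp/autotuner_cache.json"

On the first run the autotuner profiles as usual and writes the per-rank cache file; on subsequent runs the stored results are loaded and re-tuning is skipped, so the same tactics are selected deterministically.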

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Yukun He 2025-09-29 07:40:51 +08:00 committed by GitHub
parent 563e588e56
commit 28b9a81c58
3 changed files with 286 additions and 71 deletions

View File

@@ -1,14 +1,20 @@
import ast
import contextlib
import copy
import inspect
import itertools
import json
import os
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Dict, List, Set, Tuple, Union
import torch
import tensorrt_llm
from tensorrt_llm.bindings.internal.runtime import delay_kernel
from tensorrt_llm.logger import logger
@@ -201,10 +207,26 @@ class TunableRunner(ABC):
@contextlib.contextmanager
def autotune(tune_mode: bool = True):
def autotune(tune_mode: bool = True, cache_path: str = None, rank: int = 0):
# if cache_path is provided, use the rank-specific file
tune_required = tune_mode
if cache_path is not None:
# check if the rank-specific file exists
cache_path_no_ext = os.path.splitext(cache_path)[0]
cache_path_no_ext_rank = cache_path_no_ext + f".rank{rank}.json"
# if the rank-specific file exists, load it
file_exists = os.path.exists(cache_path_no_ext_rank)
# if the cache file already exists, do not enable tuning mode
tune_required = tune_required and not os.path.exists(cache_path)
if file_exists:
logger.info(
f"[Autotuner] Loading cache from {cache_path_no_ext_rank}")
AutoTuner.get().profiling_cache.load_cache(cache_path_no_ext_rank)
# record the old tuning mode
old_mode = AutoTuner.get().is_tuning_mode
AutoTuner.get().is_tuning_mode = tune_mode
autotune_enabled = tune_mode and not old_mode
AutoTuner.get().is_tuning_mode = tune_required
autotune_enabled = tune_required and not old_mode
if autotune_enabled:
logger.info("[Autotuner] Autotuning process starts ...")
try:
@@ -214,6 +236,11 @@ def autotune(tune_mode: bool = True):
if autotune_enabled:
logger.info("[Autotuner] Autotuning process ends")
# save cache
if cache_path is not None:
logger.info(f"[Autotuner] Saving cache to {cache_path_no_ext_rank}")
AutoTuner.get().profiling_cache.save_cache(cache_path_no_ext_rank)
@dataclass
class AutoTunerStatistics:
@@ -268,37 +295,40 @@ class AutoTunerStatistics:
return stats_str
class AutoTuner:
"""AutoTuner for optimizing TensorRT LLM operations.
class AutoTunerProfilingCache:
"""AutoTunerCache for caching profiling results.
This class handles automatic performance tuning of tensor operations by profiling
different implementations and caching the best performing configurations.
Args:
warmup (int): Number of warmup iterations before profiling (default: 3)
repeat (int): Number of profiling iterations for averaging (default: 10)
stream_delay_micro_secs (int): Delay on CUDA stream before the profiled kernel runs in microseconds (default: 1000)
The profiling cache can be serialized to disk for persistence across sessions:
- Use save_cache() to save the cache after tuning
- Use load_cache() to restore cached results before inference
- JSON format provides human-readable output and cross-platform compatibility
"""
_instance = None
def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
self.repeat = repeat
self.warmup = warmup
self.stream_delay_micro_secs = stream_delay_micro_secs
self.profiling_cache = {}
self.registered_tuning_configs = {}
self.is_tuning_mode = False
def __init__(self):
self.cache = {}
# Add statistics tracking
self.stats = AutoTunerStatistics()
# Cache metadata for local storage and validation
self.lib_version = tensorrt_llm.__version__
self.creation_timestamp = time.time()
# gpu_platform
self.device_name = torch.cuda.get_device_name()
self.device_capability = torch.cuda.get_device_capability()
self.profiling_debug = True
def __setitem__(self, cache_key: Tuple, value: Tuple) -> None:
self.cache[cache_key] = value
@classmethod
def get(cls):
if cls._instance is None:
cls._instance = AutoTuner()
return cls._instance
def __getitem__(self, cache_key: Tuple) -> Tuple:
return self.cache[cache_key]
def __len__(self) -> int:
return len(self.cache)
def clear(self) -> None:
self.cache.clear()
def fallback_entry(self) -> Tuple:
# runner_id = 0, tactic = -1
return 0, -1, float('inf')
def search_cache(
self,
@@ -319,12 +349,195 @@ class AutoTuner:
[is_cache_hit, runner_id, tactic, stored_profile]
"""
for r in runners:
if (cache_key := AutoTuner._get_cache_key(
custom_op, r, input_shapes,
tuning_config)) in self.profiling_cache:
return True, *self.profiling_cache[cache_key]
if (cache_key := self.get_cache_key(custom_op, r, input_shapes,
tuning_config)) in self.cache:
return True, *self.cache[cache_key]
return False, 0, -1, None
return False, *self.fallback_entry()
def get_cache_key(
self,
custom_op: str,
runner: TunableRunner,
input_shapes: Tuple[torch.Size],
tuning_config: TuningConfig,
) -> Tuple:
return (
custom_op,
runner.__class__.__name__,
hash(runner),
AutoTuner.get()._find_nearest_profile(
input_shapes,
tuning_config.dynamic_tensor_specs,
tuning_config.constraint_specs,
tuning_config.tune_max_num_tokens,
),
)
def get_specific_custom_op(self, custom_op: str) -> Dict[Tuple, Tuple]:
return {k: v for k, v in self.cache.items() if k[0] == custom_op}
def save_cache(self, file_path: Union[str, Path]) -> None:
"""Save the profiling cache to disk in JSON format.
Args:
file_path: Path where to save the cache
Raises:
IOError: If file cannot be written
Note:
The cache is saved in JSON format which provides human-readable output.
Some type information may be lost for complex tactic objects.
"""
file_path = Path(file_path)
file_path.parent.mkdir(parents=True, exist_ok=True)
try:
serializable_cache = self._serialize_cache_to_json()
with open(file_path, 'w') as f:
json.dump(serializable_cache, f, indent=2, default=str)
logger.info(
f"[AutoTuner] Successfully saved cache to {file_path} using JSON format"
)
except Exception as e:
logger.error(f"[AutoTuner] Failed to save cache with JSON: {e}")
raise
def load_cache(self, file_path: Union[str, Path]) -> None:
"""Load the profiling cache from disk in JSON format.
Args:
file_path: Path to the cache file
Raises:
FileNotFoundError: If cache file doesn't exist
IOError: If file cannot be read
Note:
Loading will replace the current cache contents. The cache is loaded
from JSON format.
"""
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"Cache file not found: {file_path}")
try:
with open(file_path, 'r') as f:
serializable_cache = json.load(f)
self.cache = self._deserialize_cache_from_json(serializable_cache)
logger.info(
f"[AutoTuner] Successfully loaded cache from {file_path} using JSON format"
)
except Exception as e:
logger.error(f"[AutoTuner] Failed to load cache with JSON: {e}")
raise
def _serialize_cache_to_json(self) -> Dict[str, Any]:
"""Convert the profiling cache to a JSON-serializable format.
Returns:
Dictionary that can be serialized to JSON
Note:
This method handles the conversion of complex objects to JSON-compatible
representations. Some type information may be lost in the conversion.
"""
serializable_cache = {
"metadata": {
"lib_version": self.lib_version,
"creation_timestamp": self.creation_timestamp,
"device_name": self.device_name,
"device_capability": self.device_capability,
},
"cache_data": {},
}
for key, value in self.cache.items():
# Convert tuple key to string for JSON compatibility
key_str = str(key)
runner_id, tactic, min_time = value
serializable_cache["cache_data"][key_str] = {
"runner_id": runner_id,
"tactic": tactic,
"min_time": min_time,
}
return serializable_cache
def _deserialize_cache_from_json(
self, serializable_cache: Dict[str, Any]) -> Dict[Tuple, Tuple]:
"""Convert JSON-serialized cache back to the original format.
Args:
serializable_cache: Dictionary loaded from JSON
Returns:
Profiling cache in the original format
Note:
This attempts to reconstruct the original data structures but may not
perfectly preserve all type information, especially for complex tactic objects.
"""
metadata = serializable_cache["metadata"]
self.lib_version = metadata["lib_version"]
self.creation_timestamp = metadata["creation_timestamp"]
self.device_name = metadata["device_name"]
self.device_capability = metadata["device_capability"]
cache = {}
cache_data = serializable_cache["cache_data"]
for key_str, value in cache_data.items():
# Reconstruct the tuple key safely
try:
key = ast.literal_eval(key_str) # Safer than eval()
except (ValueError, SyntaxError):
logger.warning(
f"[AutoTuner] Could not reconstruct cache key: {key_str}")
continue
runner_id = value["runner_id"]
tactic = value["tactic"]
min_time = value["min_time"]
cache[key] = (runner_id, tactic, min_time)
return cache
class AutoTuner:
"""AutoTuner for optimizing TensorRT LLM operations.
This class handles automatic performance tuning of tensor operations by profiling
different implementations and caching the best performing configurations.
Args:
warmup (int): Number of warmup iterations before profiling (default: 3)
repeat (int): Number of profiling iterations for averaging (default: 10)
stream_delay_micro_secs (int): Delay on CUDA stream before the profiled kernel runs in microseconds (default: 1000)
"""
_instance = None
def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
self.repeat = repeat
self.warmup = warmup
self.stream_delay_micro_secs = stream_delay_micro_secs
self.profiling_cache = AutoTunerProfilingCache()
self.is_tuning_mode = False
# Add statistics tracking
self.stats = AutoTunerStatistics()
self.profiling_debug = True
@classmethod
def get(cls):
if cls._instance is None:
cls._instance = AutoTuner()
return cls._instance
def choose_one(
self,
@@ -360,7 +573,7 @@ class AutoTuner:
input_shapes = tuple(self._get_input_sizes(inputs))
# Early return if it's not tuning, use cache found one or fallback one
if not self.is_tuning_mode:
is_cache_hit, best_runner_id, best_tactic, stored_profile = self.search_cache(
is_cache_hit, best_runner_id, best_tactic, min_time = self.profiling_cache.search_cache(
custom_op, runners, input_shapes, tuning_config)
best_runner = runners[best_runner_id]
# TODO: check the stored runner and tactic can implement this shape here
@@ -388,21 +601,21 @@ class AutoTuner:
for p in profiles:
tensors = self._prepare_input_tensors(p, inputs)
is_cache_hit, *_ = self.search_cache(custom_op, runners,
p.get_opt_shapes(),
tuning_config)
is_cache_hit, *_ = self.profiling_cache.search_cache(
custom_op, runners, p.get_opt_shapes(), tuning_config)
if not is_cache_hit:
# Initialize runner and tactic as None in case of no valid tactic or runners are found
best_runner_id, best_tactic, has_tuning_failure_occured = self._profile_runners(
best_runner_id, best_tactic, min_time, has_tuning_failure_occured = self._profile_runners(
custom_op, runners, tensors, p, tuning_config, **kwargs)
if best_runner_id is not None:
# At least one valid (runner, tactic) pair is found
cache_key = AutoTuner._get_cache_key(
cache_key = self.profiling_cache.get_cache_key(
custom_op, runners[best_runner_id], p.get_opt_shapes(),
tuning_config)
# inspect call stack
self.profiling_cache[cache_key] = (best_runner_id,
best_tactic, p)
best_tactic, min_time)
self.stats.tuned_op_successful_configs[
custom_op] = self.stats.tuned_op_successful_configs.get(
custom_op, 0) + 1
@@ -430,8 +643,8 @@ class AutoTuner:
# Get the best runner and tactic from cache
# If no valid tactic is found, the fallback runner and tactic will be used
_, runner_id, tactic, _ = self.search_cache(custom_op, runners,
input_shapes, tuning_config)
_, runner_id, tactic, _ = self.profiling_cache.search_cache(
custom_op, runners, input_shapes, tuning_config)
return (runners[runner_id], tactic)
@@ -479,9 +692,9 @@ class AutoTuner:
if custom_op not in self.stats.failed_profiling_count:
self.stats.failed_profiling_count[custom_op] = set()
self.stats.failed_profiling_count[custom_op].add(
AutoTuner._get_cache_key(custom_op, runner,
profile.get_opt_shapes(),
tuning_config))
self.profiling_cache.get_cache_key(
custom_op, runner, profile.get_opt_shapes(),
tuning_config))
# Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
# or some runtime error occurs during profiling.
@@ -491,7 +704,7 @@ class AutoTuner:
min_time = time_measured
best_runner_id, best_tactic = runner_id, tac
return best_runner_id, best_tactic, has_tuning_failure_occured
return best_runner_id, best_tactic, min_time, has_tuning_failure_occured
def _get_input_sizes(self, inputs: List[torch.Tensor]) -> List[torch.Size]:
@@ -677,20 +890,6 @@ class AutoTuner:
return tuple(tuple(shape) for shape in base_profile)
@classmethod
def _get_cache_key(
cls,
custom_op: str,
runner: TunableRunner,
input_shapes: Tuple[torch.Size],
tuning_config: TuningConfig,
) -> Tuple:
return (custom_op, runner.__class__.__name__, hash(runner),
cls._find_nearest_profile(input_shapes,
tuning_config.dynamic_tensor_specs,
tuning_config.constraint_specs,
tuning_config.tune_max_num_tokens))
def _create_tensor_like(self, origin_tensor: torch.Tensor,
dims: List[Dim]) -> torch.Tensor:
"""Create a new tensor matching the properties of the original tensor.
@@ -746,7 +945,8 @@ class AutoTuner:
logger.debug(
f"[Autotuner] Cache contents: (custom_op, runner, hash(attributes), shape_profiles) -> (runner_id, tactic, shape_profile(ignored))"
)
for key, value in self.profiling_cache.items():
runner_id, tactic, profile = value
for key, value in self.profiling_cache.cache.items():
runner_id, tactic, min_time = value
logger.debug(
f"[Autotuner] {key}: (runner_id={runner_id}, tactic={tactic})")
f"[Autotuner] {key}: (runner_id={runner_id}, tactic={tactic}, min_time={min_time})"
)
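The new AutoTunerProfilingCache can also be saved and restored directly. A minimal round-trip sketch, assuming the module path tensorrt_llm._torch.autotuner (the exact import path is not shown in this diff) and an illustrative file name:

    from tensorrt_llm._torch.autotuner import AutoTuner

    cache = AutoTuner.get().profiling_cache

    # save_cache() writes JSON containing metadata (lib version, creation
    # timestamp, device name and capability) plus cache_data mapping
    # stringified cache keys to {runner_id, tactic, min_time}.
    cache.save_cache("/tmp/autotuner_cache.rank0.json")

    # load_cache() replaces the current cache contents with the stored ones.
    cache.clear()
    cache.load_cache("/tmp/autotuner_cache.rank0.json")
    print(f"{len(cache)} entries restored")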

View File

@@ -4,6 +4,7 @@ import functools
import gc
import inspect
import math
import os
import weakref
from abc import ABC, abstractmethod
from contextlib import contextmanager
@@ -662,7 +663,10 @@ class PyTorchModelEngine(ModelEngine):
torch.cuda.synchronize()
if self.pytorch_backend_config.enable_autotuner:
with self.no_cuda_graph(), autotune():
# handle multiple rank issue
cache_path = os.environ.get("TLLM_AUTOTUNER_CACHE_PATH", None)
with self.no_cuda_graph(), autotune(cache_path=cache_path,
rank=self.mapping.rank):
result = get_autotune_warmup_request()
with release_batch(result) as batch:
if batch is None:
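The per-rank file name used for this rank follows the derivation in autotune() from the first file; a small runnable illustration (the values are made up):

    import os

    cache_path = "/tmp/autotuner_cache.json"  # value of TLLM_AUTOTUNER_CACHE_PATH
    rank = 1                                  # self.mapping.rank in the model engine
    rank_path = os.path.splitext(cache_path)[0] + f".rank{rank}.json"
    assert rank_path == "/tmp/autotuner_cache.rank1.json"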

View File

@@ -1,3 +1,5 @@
import os
import tempfile
from typing import Dict, List
import torch
@@ -262,13 +264,13 @@ def test_multiple_runners_different_attributes():
# Verify different cache keys are generated
shapes = (x.shape, w.shape)
cache_key_0 = tuner._get_cache_key(
cache_key_0 = tuner.profiling_cache.get_cache_key(
custom_op="test_multiple_runners",
input_shapes=shapes,
runner=runner_0,
tuning_config=tuning_config,
)
cache_key_1 = tuner._get_cache_key(
cache_key_1 = tuner.profiling_cache.get_cache_key(
custom_op="test_multiple_runners",
input_shapes=shapes,
runner=runner_1,
@@ -297,16 +299,25 @@ def test_multiple_dynamic_shapes_cache():
# Do tuning with a sample input
x = torch.randn(3, 64)
with autotune():
temp_dir = tempfile.TemporaryDirectory()
with autotune(cache_path=os.path.join(temp_dir.name,
"test_multiple_dynamic_shapes.json")):
tuner = AutoTuner.get()
runner, tactic = tuner.choose_one("test_multiple_dynamic_shapes",
runners, tuning_config, [x, w])
cache_entries = tuner.profiling_cache.get_specific_custom_op(
"test_multiple_dynamic_shapes")
assert len(cache_entries) == 12, \
f"Expected 12 cache entries for 3x4 shape combinations, got {len(cache_entries)}"
# Verify cache size - should have 12 entries (3x4 combinations)
cache_entries = [
k for k in tuner.profiling_cache.keys()
if k[0] == "test_multiple_dynamic_shapes"
]
# We also test the cache serialization and deserialization here.
AutoTuner.get().profiling_cache.clear()
AutoTuner.get().profiling_cache.load_cache(
os.path.join(temp_dir.name, "test_multiple_dynamic_shapes.rank0.json"))
cache_entries = tuner.profiling_cache.get_specific_custom_op(
"test_multiple_dynamic_shapes")
assert len(cache_entries) == 12, \
f"Expected 12 cache entries for 3x4 shape combinations, got {len(cache_entries)}"