[TRTLLM-10308][feat] AutoTuner Cache: reorganize cache file for distributed tuning (#10956)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Yukun He 2026-01-27 16:39:40 +08:00 committed by GitHub
parent d6f76d2fae
commit b575184fca
2 changed files with 171 additions and 18 deletions

View File

@@ -356,11 +356,23 @@ class AutoTunerProfilingCache:
- Use save_cache() to save the cache after tuning
- Use load_cache() to restore cached results before inference
- JSON format provides human-readable output and cross-platform compatibility
Cache organization:
- Ops with INDEPENDENT strategy are stored per-rank (rank_0, rank_1, ...)
- Ops with non-INDEPENDENT strategy (BROADCAST, MERGE, PARALLEL) are stored
in a shared dict since all ranks share the same tuning results
"""
# Key for shared cache entries (non-INDEPENDENT ops)
SHARED_CACHE_KEY = "shared"
def __init__(self):
self.cache: Dict[Tuple, Tuple] = dict()
# Track custom ops tuned with a non-INDEPENDENT strategy (BROADCAST,
# MERGE, PARALLEL); their cache entries are stored in the shared section
self.independent_op: Set[str] = set()
# Cache metadata for local storage and validation
self.lib_version = tensorrt_llm.__version__
self.creation_timestamp = time.time()
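
For illustration, the on-disk layout produced by this organization might look roughly like the sketch below, assuming two ranks; the op names, metadata fields, and per-entry values are invented for the example and are not the exact serialized format.

import json

# Hypothetical contents of a saved cache file (illustrative values only).
example_cache_file = {
    "metadata": {
        "lib_version": "x.y.z",            # tensorrt_llm.__version__ at save time
        "device_name": "example-gpu",
    },
    "shared": {                            # BROADCAST / MERGE / PARALLEL ops
        "('fused_attention', '(8, 128)')": [0, 1, 0.12],
    },
    "rank_0": {                            # INDEPENDENT ops, tuned per rank
        "('plain_gemm', '(1024, 1024)')": [0, 0, 0.08],
    },
    "rank_1": {
        "('plain_gemm', '(1024, 1024)')": [0, 3, 0.07],
    },
}
print(json.dumps(example_cache_file, indent=2))
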
@@ -379,6 +391,7 @@ class AutoTunerProfilingCache:
def clear(self) -> None:
self.cache.clear()
self.independent_op.clear()
def fallback_entry(self) -> Tuple:
# runner_id = 0, tactic = -1
@@ -443,11 +456,42 @@ class AutoTunerProfilingCache:
def get_specific_custom_op(self, custom_op: str) -> Dict[Tuple, Tuple]:
return {k: v for k, v in self.cache.items() if k[0] == custom_op}
def add_independent_op(self, custom_op: str,
strategy: DistributedTuningStrategy) -> None:
if strategy != DistributedTuningStrategy.INDEPENDENT:
self.independent_op.add(custom_op)
def _partition_cache_by_strategy(
self) -> Tuple[Dict[Tuple, Tuple], Dict[Tuple, Tuple]]:
"""Partition cache entries into shared and rank-specific caches.
Returns:
A tuple of (shared_cache, rank_cache) where:
- shared_cache: entries for non-INDEPENDENT ops (BROADCAST, MERGE, PARALLEL)
- rank_cache: entries for INDEPENDENT ops
"""
shared_cache = {}
rank_cache = {}
for key, value in self.cache.items():
custom_op = key[0] # First element of cache key is custom_op name
if custom_op not in self.independent_op:
rank_cache[key] = value
else:
shared_cache[key] = value
return shared_cache, rank_cache
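
A standalone toy version of the partition rule above, runnable without the AutoTuner classes; here shared_ops plays the role of the independent_op set, i.e. the ops whose results are shared across ranks.

from typing import Dict, Set, Tuple

def partition(cache: Dict[Tuple, Tuple],
              shared_ops: Set[str]) -> Tuple[Dict[Tuple, Tuple], Dict[Tuple, Tuple]]:
    shared, per_rank = {}, {}
    for key, value in cache.items():
        if key[0] in shared_ops:      # key[0] is the custom_op name
            shared[key] = value
        else:
            per_rank[key] = value
    return shared, per_rank

toy_cache = {
    ("broadcast_attn", "shape_a"): (0, 1, 0.10),
    ("per_rank_gemm", "shape_a"): (0, 2, 0.08),
}
shared, per_rank = partition(toy_cache, shared_ops={"broadcast_attn"})
assert ("broadcast_attn", "shape_a") in shared
assert ("per_rank_gemm", "shape_a") in per_rank
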
def save_cache(self, file_path: Union[str, Path], rank: int) -> None:
"""Save the profiling cache to disk in JSON format.
Cache entries are organized based on distributed strategy:
- INDEPENDENT ops are saved per-rank (rank_0, rank_1, ...)
- Non-INDEPENDENT ops (BROADCAST, MERGE, PARALLEL) are saved in a shared dict
Args:
file_path: Path where to save the cache
rank: The rank of the current process
Raises:
IOError: If file cannot be written
@@ -460,7 +504,12 @@ class AutoTunerProfilingCache:
file_path.parent.mkdir(parents=True, exist_ok=True)
try:
serialized_rank_cache_data = self._serialize_cache_data()
# Partition cache into shared (non-INDEPENDENT) and rank-specific (INDEPENDENT)
shared_cache, rank_cache = self._partition_cache_by_strategy()
serialized_shared_cache = self._serialize_cache_data(shared_cache)
serialized_rank_cache = self._serialize_cache_data(rank_cache)
with open(file_path, 'a+') as f:
fcntl.flock(f, fcntl.LOCK_EX)
f.seek(0)
@@ -473,7 +522,16 @@ class AutoTunerProfilingCache:
}
f.seek(0)
f.truncate()
current_cache[f"rank_{rank}"] = serialized_rank_cache_data
# Merge shared cache entries (non-INDEPENDENT ops)
if self.SHARED_CACHE_KEY not in current_cache:
current_cache[self.SHARED_CACHE_KEY] = {}
current_cache[self.SHARED_CACHE_KEY].update(
serialized_shared_cache)
# Save rank-specific cache entries (INDEPENDENT ops)
current_cache[f"rank_{rank}"] = serialized_rank_cache
json.dump(current_cache, f, indent=2, default=str)
logger.info(
f"[AutoTuner] Successfully saved cache to {file_path} using JSON format"
@@ -485,8 +543,12 @@ class AutoTunerProfilingCache:
def load_cache(self, file_path: Union[str, Path], rank: int) -> None:
"""Load the profiling cache from disk in JSON format.
Loads both shared cache entries (non-INDEPENDENT ops) and rank-specific
entries (INDEPENDENT ops) and merges them into the current cache.
Args:
file_path: Path to the cache file
rank: The rank of the current process
Raises:
FileNotFoundError: If cache file doesn't exist
@@ -505,10 +567,39 @@ class AutoTunerProfilingCache:
fcntl.flock(f, fcntl.LOCK_SH)
current_cache_contents = json.load(f)
self._deserialize_metadata(current_cache_contents["metadata"])
self.cache = self._deserialize_cache_data(
current_cache_contents.get(f'rank_{rank}', {}))
# Start with empty cache and independent ops set
self.cache = {}
self.independent_op = set()
# Load shared cache entries (non-INDEPENDENT ops)
if self.SHARED_CACHE_KEY in current_cache_contents:
shared_cache = self._deserialize_cache_data(
current_cache_contents[self.SHARED_CACHE_KEY])
self.cache.update(shared_cache)
# Ops found in the shared section use non-INDEPENDENT strategies; track them
for key in shared_cache.keys():
self.independent_op.add(key[0])
logger.debug(
f"[AutoTuner] Loaded {len(shared_cache)} shared cache entries"
)
# Load rank-specific cache entries (INDEPENDENT ops)
rank_key = f"rank_{rank}"
if rank_key in current_cache_contents:
rank_cache = self._deserialize_cache_data(
current_cache_contents[rank_key])
self.cache.update(rank_cache)
logger.debug(
f"[AutoTuner] Loaded {len(rank_cache)} rank-specific cache entries for rank {rank}"
)
logger.info(
f"[AutoTuner] Successfully loaded cache from {file_path} using JSON format"
f"[AutoTuner] Successfully loaded cache from {file_path} using JSON format (total {len(self.cache)} entries)"
)
logger.info(
f"[AutoTuner] independent_op: {type(self.independent_op) if hasattr(self, 'independent_op') else 'not found'}"
)
except Exception as e:
logger.error(f"[AutoTuner] Failed to load cache with JSON: {e}")
@@ -528,9 +619,14 @@ class AutoTunerProfilingCache:
self.device_name = metadata["device_name"]
self.device_capability = metadata["device_capability"]
def _serialize_cache_data(self) -> Dict[str, Any]:
def _serialize_cache_data(self,
cache: Optional[Dict[Tuple, Tuple]] = None
) -> Dict[str, Any]:
"""Convert the profiling cache to a JSON-serializable format.
Args:
cache: Optional cache dict to serialize. If None, uses self.cache.
Returns:
Dictionary that can be serialized to JSON
@@ -538,9 +634,12 @@ class AutoTunerProfilingCache:
This method handles the conversion of complex objects to JSON-compatible
representations. Some type information may be lost in the conversion.
"""
if cache is None:
cache = self.cache
serializable_cache = {}
for key, value in self.cache.items():
for key, value in cache.items():
# Convert any simple object to string for JSON compatibility
key_str = str(key)
runner_id, tactic, min_time = value
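
A toy illustration of the str(key) conversion performed here; the per-value field names are an assumption for the example, and, as the docstring notes, the original tuple types cannot be recovered from the stringified keys without extra parsing.

import json

toy_cache = {("plain_gemm", (1024, 1024), "float16"): (0, 2, 0.07)}
serializable = {
    str(key): {"runner_id": v[0], "tactic": v[1], "min_time": v[2]}
    for key, v in toy_cache.items()
}
print(json.dumps(serializable, indent=2))
# Keys come back as plain strings like "('plain_gemm', (1024, 1024), 'float16')".
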
@@ -636,8 +735,6 @@ class AutoTuner:
self._last_capture: Optional['AutoTuner.TacticsCapture'] = None
# Distributed tuning state
self._map_op_to_distributed_strategy: Dict[
str, DistributedTuningStrategy] = {}
self._dist: Optional[Distributed] = None
self._has_received_cache: bool = False
self.mapping: Mapping = Mapping()
@@ -844,8 +941,8 @@ class AutoTuner:
"All Given runners must be subclass of TunableRunner"
# Record the distributed tuning strategy for the custom_op
self._map_op_to_distributed_strategy[
custom_op] = tuning_config.distributed_tuning_strategy
self.profiling_cache.add_independent_op(
custom_op, strategy=tuning_config.distributed_tuning_strategy)
tuning_start_time = time.perf_counter()
profiles = self._optimization_profiles(tuning_config, inputs)
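
A runnable toy of the recording rule that add_independent_op applies at choose_one time, using a stand-in enum so it does not depend on tensorrt_llm: only ops tuned with a non-INDEPENDENT strategy end up tracked as shared.

from enum import Enum

class Strategy(Enum):           # stand-in for DistributedTuningStrategy
    INDEPENDENT = "independent"
    BROADCAST = "broadcast"
    MERGE = "merge"
    PARALLEL = "parallel"

shared_ops: set = set()

def record(custom_op: str, strategy: Strategy) -> None:
    # Mirrors add_independent_op(): INDEPENDENT ops stay per-rank, the rest are shared.
    if strategy != Strategy.INDEPENDENT:
        shared_ops.add(custom_op)

record("per_rank_gemm", Strategy.INDEPENDENT)
record("broadcast_attn", Strategy.BROADCAST)
assert shared_ops == {"broadcast_attn"}
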

View File

@@ -1,4 +1,5 @@
import itertools
import json
import os
import pickle
import sys
@@ -736,7 +737,10 @@ def _distributed_worker_function(world_size, strategy):
# Each rank prefers different tactics
prefer_tactics = [rank]
runner = DistributedGemmRunner(prefer_tactics=prefer_tactics)
runner_independent = DistributedGemmRunner()
config = TuningConfig(distributed_tuning_strategy=strategy)
config_independent = TuningConfig(
distributed_tuning_strategy=DistributedTuningStrategy.INDEPENDENT)
# Keep temp_dir in function scope to prevent premature garbage collection
temp_dir = None
@@ -749,40 +753,92 @@ def _distributed_worker_function(world_size, strategy):
cache_path = dist.broadcast(None, root=0)
with autotune(cache_path=cache_path):
tuner.choose_one(custom_op=f"test_distributed_{strategy}",
tuner.choose_one(custom_op=f"test_distributed_{strategy.value}",
runners=[runner],
tuning_config=config,
inputs=inputs)
# run another normal gemm with INDEPENDENT strategy
tuner.choose_one(custom_op=f"test_distributed_normal_gemm",
runners=[runner_independent],
tuning_config=config_independent,
inputs=inputs)
# Check only one file is created in the cache path
assert len(os.listdir(os.path.dirname(
cache_path))) == 1, "Only one rank file should be created"
dist.barrier()
# Check cache for distributed tuning
AutoTuner.get().profiling_cache.clear()
AutoTuner.get().profiling_cache.load_cache(cache_path, rank)
selected_runner, best_tactic = tuner.choose_one(
custom_op=f"test_distributed_{strategy}",
custom_op=f"test_distributed_{strategy.value}",
runners=[runner],
tuning_config=config,
inputs=inputs)
# Verify cache file structure based on distributed strategy
with open(cache_path, 'r') as f:
cache_data = json.load(f)
# Helper to check if an op name appears in any cache key string
def has_op_in_section(section_data: dict, op_name: str) -> bool:
return any(op_name in key_str for key_str in section_data.keys())
assert 'metadata' in cache_data, "Metadata should be present"
assert f'rank_{rank}' in cache_data, f"rank {rank} should be present"
# The INDEPENDENT op "test_distributed_normal_gemm" should always be in rank-specific sections
assert has_op_in_section(cache_data[f'rank_{rank}'], 'test_distributed_normal_gemm'), \
f"rank {rank} should have test_distributed_normal_gemm"
if strategy == DistributedTuningStrategy.INDEPENDENT:
# Both ops use INDEPENDENT strategy, so no shared section
assert 'shared' not in cache_data or len(cache_data.get('shared', {})) == 0, \
"shared should not be present or be empty for INDEPENDENT strategy"
# Each rank should have 2 entries (the parameterized op + normal_gemm)
assert len(cache_data[f'rank_{rank}']) == 2, \
f"rank {rank} should have 2 entries, got {len(cache_data[f'rank_{rank}'])}"
assert has_op_in_section(cache_data[f'rank_{rank}'], f'test_distributed_{strategy.value}'), \
f"rank {rank} should have test_distributed_{strategy.value}"
assert len(
AutoTuner.get().profiling_cache.independent_op
) == 0, f"Non-INDEPENDENT ops should not be present in the cache"
else:
# Non-INDEPENDENT ops go to shared section
assert 'shared' in cache_data, "shared section should be present"
# Each rank should have only 1 entry (the normal_gemm with INDEPENDENT strategy)
assert len(cache_data[f'rank_{rank}']) == 1, \
f"rank {rank} should have 1 entry, got {len(cache_data[f'rank_{rank}'])}"
# The parameterized op should NOT be in rank-specific section
assert not has_op_in_section(cache_data[f'rank_{rank}'], f'test_distributed_{strategy.value}'), \
f"rank {rank} should not have test_distributed_{strategy.value}"
# The parameterized op should be in shared section
assert has_op_in_section(cache_data['shared'], f'test_distributed_{strategy.value}'), \
f"shared should have test_distributed_{strategy.value}"
assert "test_distributed_normal_gemm" not in AutoTuner.get().profiling_cache.independent_op and \
f"test_distributed_{strategy.value}" in AutoTuner.get().profiling_cache.independent_op, \
f"Distributed tuning strategy is not recovered correctly from cache"
if strategy == DistributedTuningStrategy.BROADCAST:
# All ranks should select tactic 0
assert best_tactic == 0
assert best_tactic == 0, f"Rank {rank} with {strategy} should select tactic 0, got {best_tactic}"
elif strategy == DistributedTuningStrategy.INDEPENDENT:
# Each rank should select the tactic it prefers
assert best_tactic == rank
assert best_tactic == rank, f"Rank {rank} with {strategy} should select tactic {rank}, got {best_tactic}"
elif strategy == DistributedTuningStrategy.MERGE:
# Because tactic 0 is slower, two ranks should always select tactic 1
assert best_tactic == 1
assert best_tactic == 1, f"Rank {rank} with {strategy} should select tactic 1, got {best_tactic}"
elif strategy == DistributedTuningStrategy.PARALLEL:
# Tactic 1 or 3 should be selected since they are faster.
# TODO: This might not cover the case that rank1 tunes nothing
assert best_tactic % 2 == 1
assert best_tactic % 2 == 1, f"Rank {rank} with {strategy} should select tactic 1, got {best_tactic}"
else:
assert False, f"Unknown strategy: {strategy}"
assert False, f"Rank {rank} got unknown strategy: {strategy}"
dist.barrier()
return True