From b575184fca875b0a768d49bd3db592e4cd60ace7 Mon Sep 17 00:00:00 2001
From: Yukun He <23156053+hyukn@users.noreply.github.com>
Date: Tue, 27 Jan 2026 16:39:40 +0800
Subject: [PATCH] [TRTLLM-10308][feat] AutoTuner Cache: reorganize cache file
 for distributed tuning (#10956)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
---
 tensorrt_llm/_torch/autotuner.py              | 119 +++++++++++++++++--
 tests/unittest/_torch/misc/test_autotuner.py  |  70 +++++++++--
 2 files changed, 171 insertions(+), 18 deletions(-)

diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py
index 89fb89a054..77e2682bf1 100644
--- a/tensorrt_llm/_torch/autotuner.py
+++ b/tensorrt_llm/_torch/autotuner.py
@@ -356,11 +356,23 @@ class AutoTunerProfilingCache:
     - Use save_cache() to save the cache after tuning
     - Use load_cache() to restore cached results before inference
     - JSON format provides human-readable output and cross-platform compatibility
+
+    Cache organization:
+    - Ops with INDEPENDENT strategy are stored per-rank (rank_0, rank_1, ...)
+    - Ops with non-INDEPENDENT strategy (BROADCAST, MERGE, PARALLEL) are stored
+      in a shared dict since all ranks share the same tuning results
     """
 
+    # Key for shared cache entries (non-INDEPENDENT ops)
+    SHARED_CACHE_KEY = "shared"
+
     def __init__(self):
         self.cache: Dict[Tuple, Tuple] = dict()
 
+        # Names of custom ops tuned with a non-INDEPENDENT strategy
+        # (BROADCAST, MERGE, PARALLEL); their cache entries are shared across ranks
+        self.independent_op: Set[str] = set()
+
         # Cache metadata for local storage and validation
         self.lib_version = tensorrt_llm.__version__
         self.creation_timestamp = time.time()
@@ -379,6 +391,7 @@ class AutoTunerProfilingCache:
 
     def clear(self) -> None:
         self.cache.clear()
+        self.independent_op.clear()
 
     def fallback_entry(self) -> Tuple:
         # runner_id = 0, tactic = -1
@@ -443,11 +456,42 @@ class AutoTunerProfilingCache:
     def get_specific_custom_op(self, custom_op: str) -> Dict[Tuple, Tuple]:
         return {k: v for k, v in self.cache.items() if k[0] == custom_op}
 
+    def add_independent_op(self, custom_op: str,
+                           strategy: DistributedTuningStrategy) -> None:
+        if strategy != DistributedTuningStrategy.INDEPENDENT:
+            self.independent_op.add(custom_op)  # results shared across ranks
+
+    def _partition_cache_by_strategy(
+            self) -> Tuple[Dict[Tuple, Tuple], Dict[Tuple, Tuple]]:
+        """Partition cache entries into shared and rank-specific caches.
+
+        Returns:
+            A tuple of (shared_cache, rank_cache) where:
+            - shared_cache: entries for non-INDEPENDENT ops (BROADCAST, MERGE, PARALLEL)
+            - rank_cache: entries for INDEPENDENT ops
+        """
+        shared_cache = {}
+        rank_cache = {}
+
+        for key, value in self.cache.items():
+            custom_op = key[0]  # First element of cache key is custom_op name
+            if custom_op not in self.independent_op:
+                rank_cache[key] = value  # INDEPENDENT op: kept per-rank
+            else:
+                shared_cache[key] = value  # non-INDEPENDENT op: shared by all ranks
+
+        return shared_cache, rank_cache
+
     def save_cache(self, file_path: Union[str, Path], rank: int) -> None:
         """Save the profiling cache to disk in JSON format.
 
+        Cache entries are organized based on distributed strategy:
+        - INDEPENDENT ops are saved per-rank (rank_0, rank_1, ...)
+ - Non-INDEPENDENT ops (BROADCAST, MERGE, PARALLEL) are saved in a shared dict + Args: file_path: Path where to save the cache + rank: The rank of the current process Raises: IOError: If file cannot be written @@ -460,7 +504,12 @@ class AutoTunerProfilingCache: file_path.parent.mkdir(parents=True, exist_ok=True) try: - serialized_rank_cache_data = self._serialize_cache_data() + # Partition cache into shared (non-INDEPENDENT) and rank-specific (INDEPENDENT) + shared_cache, rank_cache = self._partition_cache_by_strategy() + + serialized_shared_cache = self._serialize_cache_data(shared_cache) + serialized_rank_cache = self._serialize_cache_data(rank_cache) + with open(file_path, 'a+') as f: fcntl.flock(f, fcntl.LOCK_EX) f.seek(0) @@ -473,7 +522,16 @@ class AutoTunerProfilingCache: } f.seek(0) f.truncate() - current_cache[f"rank_{rank}"] = serialized_rank_cache_data + + # Merge shared cache entries (non-INDEPENDENT ops) + if self.SHARED_CACHE_KEY not in current_cache: + current_cache[self.SHARED_CACHE_KEY] = {} + current_cache[self.SHARED_CACHE_KEY].update( + serialized_shared_cache) + + # Save rank-specific cache entries (INDEPENDENT ops) + current_cache[f"rank_{rank}"] = serialized_rank_cache + json.dump(current_cache, f, indent=2, default=str) logger.info( f"[AutoTuner] Successfully saved cache to {file_path} using JSON format" @@ -485,8 +543,12 @@ class AutoTunerProfilingCache: def load_cache(self, file_path: Union[str, Path], rank: int) -> None: """Load the profiling cache from disk in JSON format. + Loads both shared cache entries (non-INDEPENDENT ops) and rank-specific + entries (INDEPENDENT ops) and merges them into the current cache. + Args: file_path: Path to the cache file + rank: The rank of the current process Raises: FileNotFoundError: If cache file doesn't exist @@ -505,10 +567,39 @@ class AutoTunerProfilingCache: fcntl.flock(f, fcntl.LOCK_SH) current_cache_contents = json.load(f) self._deserialize_metadata(current_cache_contents["metadata"]) - self.cache = self._deserialize_cache_data( - current_cache_contents.get(f'rank_{rank}', {})) + + # Start with empty cache and independent ops set + self.cache = {} + self.independent_op = set() + + # Load shared cache entries (non-INDEPENDENT ops) + if self.SHARED_CACHE_KEY in current_cache_contents: + shared_cache = self._deserialize_cache_data( + current_cache_contents[self.SHARED_CACHE_KEY]) + self.cache.update(shared_cache) + # add custom op in shared cache to independent ops set + for key in shared_cache.keys(): + self.independent_op.add(key[0]) + logger.debug( + f"[AutoTuner] Loaded {len(shared_cache)} shared cache entries" + ) + + # Load rank-specific cache entries (INDEPENDENT ops) + rank_key = f"rank_{rank}" + if rank_key in current_cache_contents: + rank_cache = self._deserialize_cache_data( + current_cache_contents[rank_key]) + self.cache.update(rank_cache) + logger.debug( + f"[AutoTuner] Loaded {len(rank_cache)} rank-specific cache entries for rank {rank}" + ) + logger.info( - f"[AutoTuner] Successfully loaded cache from {file_path} using JSON format" + f"[AutoTuner] Successfully loaded cache from {file_path} using JSON format (total {len(self.cache)} entries)" + ) + + logger.info( + f"[AutoTuner] independent_op: {type(self.independent_op) if hasattr(self, 'independent_op') else 'not found'}" ) except Exception as e: logger.error(f"[AutoTuner] Failed to load cache with JSON: {e}") @@ -528,9 +619,14 @@ class AutoTunerProfilingCache: self.device_name = metadata["device_name"] self.device_capability = 
metadata["device_capability"] - def _serialize_cache_data(self) -> Dict[str, Any]: + def _serialize_cache_data(self, + cache: Optional[Dict[Tuple, Tuple]] = None + ) -> Dict[str, Any]: """Convert the profiling cache to a JSON-serializable format. + Args: + cache: Optional cache dict to serialize. If None, uses self.cache. + Returns: Dictionary that can be serialized to JSON @@ -538,9 +634,12 @@ class AutoTunerProfilingCache: This method handles the conversion of complex objects to JSON-compatible representations. Some type information may be lost in the conversion. """ + if cache is None: + cache = self.cache + serializable_cache = {} - for key, value in self.cache.items(): + for key, value in cache.items(): # Convert any simple object to string for JSON compatibility key_str = str(key) runner_id, tactic, min_time = value @@ -636,8 +735,6 @@ class AutoTuner: self._last_capture: Optional['AutoTuner.TacticsCapture'] = None # Dsitributed tuning state - self._map_op_to_distributed_strategy: Dict[ - str, DistributedTuningStrategy] = {} self._dist: Optional[Distributed] = None self._has_received_cache: bool = False self.mapping: Mapping = Mapping() @@ -844,8 +941,8 @@ class AutoTuner: "All Given runners must be subclass of TunableRunner" # Record the distributed tuning strategy for the custom_op - self._map_op_to_distributed_strategy[ - custom_op] = tuning_config.distributed_tuning_strategy + self.profiling_cache.add_independent_op( + custom_op, strategy=tuning_config.distributed_tuning_strategy) tuning_start_time = time.perf_counter() profiles = self._optimization_profiles(tuning_config, inputs) diff --git a/tests/unittest/_torch/misc/test_autotuner.py b/tests/unittest/_torch/misc/test_autotuner.py index c5021ef7f9..a41fe443a5 100644 --- a/tests/unittest/_torch/misc/test_autotuner.py +++ b/tests/unittest/_torch/misc/test_autotuner.py @@ -1,4 +1,5 @@ import itertools +import json import os import pickle import sys @@ -736,7 +737,10 @@ def _distributed_worker_function(world_size, strategy): # Each rank prefers different tactics prefer_tactics = [rank] runner = DistributedGemmRunner(prefer_tactics=prefer_tactics) + runner_independent = DistributedGemmRunner() config = TuningConfig(distributed_tuning_strategy=strategy) + config_independent = TuningConfig( + distributed_tuning_strategy=DistributedTuningStrategy.INDEPENDENT) # Keep temp_dir in function scope to prevent premature garbage collection temp_dir = None @@ -749,40 +753,92 @@ def _distributed_worker_function(world_size, strategy): cache_path = dist.broadcast(None, root=0) with autotune(cache_path=cache_path): - tuner.choose_one(custom_op=f"test_distributed_{strategy}", + tuner.choose_one(custom_op=f"test_distributed_{strategy.value}", runners=[runner], tuning_config=config, inputs=inputs) + # run another normal gemm with INDEPENDENT strategy + tuner.choose_one(custom_op=f"test_distributed_normal_gemm", + runners=[runner_independent], + tuning_config=config_independent, + inputs=inputs) # Check only one file is created in the cache path assert len(os.listdir(os.path.dirname( cache_path))) == 1, "Only one rank file should be created" + dist.barrier() + # Check cache for distributed tuning AutoTuner.get().profiling_cache.clear() AutoTuner.get().profiling_cache.load_cache(cache_path, rank) selected_runner, best_tactic = tuner.choose_one( - custom_op=f"test_distributed_{strategy}", + custom_op=f"test_distributed_{strategy.value}", runners=[runner], tuning_config=config, inputs=inputs) + # Verify cache file structure based on distributed strategy 
+ with open(cache_path, 'r') as f: + cache_data = json.load(f) + + # Helper to check if an op name appears in any cache key string + def has_op_in_section(section_data: dict, op_name: str) -> bool: + return any(op_name in key_str for key_str in section_data.keys()) + + assert 'metadata' in cache_data, "Metadata should be present" + assert f'rank_{rank}' in cache_data, f"rank {rank} should be present" + + # The INDEPENDENT op "test_distributed_normal_gemm" should always be in rank-specific sections + assert has_op_in_section(cache_data[f'rank_{rank}'], 'test_distributed_normal_gemm'), \ + f"rank {rank} should have test_distributed_normal_gemm" + + if strategy == DistributedTuningStrategy.INDEPENDENT: + # Both ops use INDEPENDENT strategy, so no shared section + assert 'shared' not in cache_data or len(cache_data.get('shared', {})) == 0, \ + "shared should not be present or be empty for INDEPENDENT strategy" + # Each rank should have 2 entries (the parameterized op + normal_gemm) + assert len(cache_data[f'rank_{rank}']) == 2, \ + f"rank {rank} should have 2 entries, got {len(cache_data[f'rank_{rank}'])}" + assert has_op_in_section(cache_data[f'rank_{rank}'], f'test_distributed_{strategy.value}'), \ + f"rank {rank} should have test_distributed_{strategy.value}" + + assert len( + AutoTuner.get().profiling_cache.independent_op + ) == 0, f"Non-INDEPENDENT ops should not be present in the cache" + else: + # Non-INDEPENDENT ops go to shared section + assert 'shared' in cache_data, "shared section should be present" + # Each rank should have only 1 entry (the normal_gemm with INDEPENDENT strategy) + assert len(cache_data[f'rank_{rank}']) == 1, \ + f"rank {rank} should have 1 entry, got {len(cache_data[f'rank_{rank}'])}" + # The parameterized op should NOT be in rank-specific section + assert not has_op_in_section(cache_data[f'rank_{rank}'], f'test_distributed_{strategy.value}'), \ + f"rank {rank} should not have test_distributed_{strategy.value}" + # The parameterized op should be in shared section + assert has_op_in_section(cache_data['shared'], f'test_distributed_{strategy.value}'), \ + f"shared should have test_distributed_{strategy.value}" + + assert "test_distributed_normal_gemm" not in AutoTuner.get().profiling_cache.independent_op and \ + f"test_distributed_{strategy.value}" in AutoTuner.get().profiling_cache.independent_op, \ + f"Distributed tuning strategy is not recovered correctly from cache" + if strategy == DistributedTuningStrategy.BROADCAST: # All ranks should select tactic 0 - assert best_tactic == 0 + assert best_tactic == 0, f"Rank {rank} with {strategy} should select tactic 0, got {best_tactic}" elif strategy == DistributedTuningStrategy.INDEPENDENT: # Each rank should select the tactic it prefers - assert best_tactic == rank + assert best_tactic == rank, f"Rank {rank} with {strategy} should select tactic {rank}, got {best_tactic}" elif strategy == DistributedTuningStrategy.MERGE: # Because tactic 0 is slower, two ranks should always select tactic 1 - assert best_tactic == 1 + assert best_tactic == 1, f"Rank {rank} with {strategy} should select tactic 1, got {best_tactic}" elif strategy == DistributedTuningStrategy.PARALLEL: # Tactic 1 or 3 should be selected since they are faster. 
         # TODO: This might not cover the case that rank1 tunes nothing
-        assert best_tactic % 2 == 1
+        assert best_tactic % 2 == 1, f"Rank {rank} with {strategy} should select tactic 1 or 3, got {best_tactic}"
     else:
-        assert False, f"Unknown strategy: {strategy}"
+        assert False, f"Rank {rank} got unknown strategy: {strategy}"
 
     dist.barrier()
     return True
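
Illustration of the resulting cache-file layout (a sketch, not part of the patch): save_cache() writes a single JSON document whose top-level keys are "metadata", "shared" (entries for BROADCAST/MERGE/PARALLEL ops, identical on every rank), and one "rank_<N>" section per rank (entries for INDEPENDENT ops). The snippet below mirrors what load_cache() does for one rank; the helper name, example file name, and printed fields are illustrative assumptions, and the exact value layout is whatever _serialize_cache_data() emits.

# Sketch only: a standalone reader for the JSON cache layout described above.
# "metadata", "shared", and "rank_<N>" are the keys written by save_cache();
# inspect_autotuner_cache() and the example file name are hypothetical.
import json


def inspect_autotuner_cache(file_path: str, rank: int) -> dict:
    """Merge shared (non-INDEPENDENT) and per-rank (INDEPENDENT) entries."""
    with open(file_path, "r") as f:
        contents = json.load(f)

    merged = {}
    # Entries shared by all ranks: ops tuned with BROADCAST, MERGE, or PARALLEL.
    merged.update(contents.get("shared", {}))
    # Entries specific to this rank: ops tuned with the INDEPENDENT strategy.
    merged.update(contents.get(f"rank_{rank}", {}))
    return merged


if __name__ == "__main__":
    entries = inspect_autotuner_cache("autotuner_cache.json", rank=0)
    print(f"loaded {len(entries)} entries")
    for key_str, value in entries.items():
        # Each key is str(cache_key) with the custom_op name first; each value
        # carries the chosen runner_id, tactic, and measured min_time.
        print(key_str, "->", value)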