[TRTLLM-10185][feat] AutoTuner Cache: Support cache file locking and merge all ranks' caches into one file (#10336)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Authored by Yukun He on 2026-01-05 13:44:09 +08:00, committed by GitHub
parent 5a8bfcbb50
commit 0937df2c68
3 changed files with 84 additions and 49 deletions
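
This change replaces the per-rank cache files (previously written as <cache_path_without_extension>.rank{N}.json) with a single JSON file shared by all ranks: on save, each rank takes an exclusive fcntl lock, writes its entries into its own rank_{N} section, and rewrites the file; on load, a rank takes a shared lock and restores only its section. A sketch of the resulting on-disk layout implied by the diff below (field values are placeholders):

{
  "metadata": {
    "lib_version": "...",
    "creation_timestamp": "...",
    "device_name": "...",
    "device_capability": "..."
  },
  "rank_0": {
    "<serialized cache key>": {"runner_id": 0, "tactic": "...", "min_time": 0.0}
  },
  "rank_1": {
    "...": "..."
  }
}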

View File

@@ -2,6 +2,7 @@ import ast
import contextlib
import copy
import enum
import fcntl
import inspect
import itertools
import json
@@ -266,15 +267,11 @@ def autotune(tune_mode: bool = True, cache_path: str = None):
tune_required = tune_mode
if cache_path is not None:
# check if the rank-specific file exists
cache_path_no_ext = os.path.splitext(cache_path)[0]
cache_path_no_ext_rank = cache_path_no_ext + f".rank{rank}.json"
# if the rank-specific file exists, load it
file_exists = os.path.exists(cache_path_no_ext_rank)
# if the cache file exists, do not enable tuning mode
file_exists = os.path.exists(cache_path)
if file_exists:
logger.info(
f"[Autotuner] Loading cache from {cache_path_no_ext_rank}")
autotuner.profiling_cache.load_cache(cache_path_no_ext_rank)
logger.info(f"[Autotuner] Loading cache from {cache_path}")
autotuner.profiling_cache.load_cache(cache_path, rank)
# record the old tuning mode
old_mode = autotuner.is_tuning_mode
@@ -293,8 +290,8 @@ def autotune(tune_mode: bool = True, cache_path: str = None):
# save cache
if cache_path is not None:
logger.info(f"[Autotuner] Saving cache to {cache_path_no_ext_rank}")
autotuner.profiling_cache.save_cache(cache_path_no_ext_rank)
logger.info(f"[Autotuner] Saving cache to {cache_path}")
autotuner.profiling_cache.save_cache(cache_path, rank)
@dataclass
@@ -439,7 +436,7 @@ class AutoTunerProfilingCache:
def get_specific_custom_op(self, custom_op: str) -> Dict[Tuple, Tuple]:
return {k: v for k, v in self.cache.items() if k[0] == custom_op}
def save_cache(self, file_path: Union[str, Path]) -> None:
def save_cache(self, file_path: Union[str, Path], rank: int) -> None:
"""Save the profiling cache to disk in JSON format.
Args:
@@ -456,9 +453,21 @@ class AutoTunerProfilingCache:
file_path.parent.mkdir(parents=True, exist_ok=True)
try:
serializable_cache = self._serialize_cache_to_json()
with open(file_path, 'w') as f:
json.dump(serializable_cache, f, indent=2, default=str)
serialized_rank_cache_data = self._serialize_cache_data()
with open(file_path, 'a+') as f:
fcntl.flock(f, fcntl.LOCK_EX)
f.seek(0)
content = f.read()
if content.strip():
current_cache = json.loads(content)
else:
current_cache = {
"metadata": self._serialize_metadata(),
}
f.seek(0)
f.truncate()
current_cache[f"rank_{rank}"] = serialized_rank_cache_data
json.dump(current_cache, f, indent=2, default=str)
logger.info(
f"[AutoTuner] Successfully saved cache to {file_path} using JSON format"
)
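
For illustration, with this exclusive-lock read-merge-rewrite scheme every rank can write to the same path concurrently and still end up with one well-formed file. A minimal usage sketch (the cache_path value and rank variable are illustrative, not taken from this diff):

    # each rank merges its own section into the shared file; flock serializes the writers
    cache_path = "/tmp/autotuner_cache.json"
    AutoTuner.get().profiling_cache.save_cache(cache_path, rank)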
@@ -466,7 +475,7 @@ class AutoTunerProfilingCache:
logger.error(f"[AutoTuner] Failed to save cache with JSON: {e}")
raise
def load_cache(self, file_path: Union[str, Path]) -> None:
def load_cache(self, file_path: Union[str, Path], rank: int) -> None:
"""Load the profiling cache from disk in JSON format.
Args:
@@ -486,8 +495,12 @@ class AutoTunerProfilingCache:
try:
with open(file_path, 'r') as f:
serializable_cache = json.load(f)
self.cache = self._deserialize_cache_from_json(serializable_cache)
fcntl.flock(f, fcntl.LOCK_SH)
current_cache_contents = json.load(f)
self._deserialize_metadata(current_cache_contents["metadata"])
assert f"rank_{rank}" in current_cache_contents, f"Rank {rank} cache not found in {file_path}"
self.cache = self._deserialize_cache_data(
current_cache_contents[f'rank_{rank}'])
logger.info(
f"[AutoTuner] Successfully loaded cache from {file_path} using JSON format"
)
@@ -495,7 +508,21 @@ class AutoTunerProfilingCache:
logger.error(f"[AutoTuner] Failed to load cache with JSON: {e}")
raise
def _serialize_cache_to_json(self) -> Dict[str, Any]:
def _serialize_metadata(self) -> Dict[str, Any]:
return {
"lib_version": self.lib_version,
"creation_timestamp": self.creation_timestamp,
"device_name": self.device_name,
"device_capability": self.device_capability,
}
def _deserialize_metadata(self, metadata: Dict[str, Any]) -> None:
self.lib_version = metadata["lib_version"]
self.creation_timestamp = metadata["creation_timestamp"]
self.device_name = metadata["device_name"]
self.device_capability = metadata["device_capability"]
def _serialize_cache_data(self) -> Dict[str, Any]:
"""Convert the profiling cache to a JSON-serializable format.
Returns:
@@ -505,15 +532,7 @@ class AutoTunerProfilingCache:
This method handles the conversion of complex objects to JSON-compatible
representations. Some type information may be lost in the conversion.
"""
serializable_cache = {
"metadata": {
"lib_version": self.lib_version,
"creation_timestamp": self.creation_timestamp,
"device_name": self.device_name,
"device_capability": self.device_capability,
},
"cache_data": {},
}
serializable_cache = {}
for key, value in self.cache.items():
# Convert any simple object to string for JSON compatibility
@@ -529,7 +548,7 @@ class AutoTunerProfilingCache:
f"[AutoTuner] Could not serialize tactic: {tactic_str} for cache key {key_str} due to {e}. Deserialization may fail.",
key=tactic_str)
serializable_cache["cache_data"][key_str] = {
serializable_cache[key_str] = {
"runner_id": runner_id,
"tactic": tactic_str,
"min_time": min_time,
@@ -537,8 +556,8 @@ class AutoTunerProfilingCache:
return serializable_cache
def _deserialize_cache_from_json(
self, serializable_cache: Dict[str, Any]) -> Dict[Tuple, Tuple]:
def _deserialize_cache_data(
self, cache_data: Dict[str, Any]) -> Dict[Tuple, Tuple]:
"""Convert JSON-serialized cache back to the original format.
Args:
@@ -551,14 +570,7 @@ class AutoTunerProfilingCache:
This attempts to reconstruct the original data structures but may not
perfectly preserve all type information, especially for complex tactic objects.
"""
metadata = serializable_cache["metadata"]
self.lib_version = metadata["lib_version"]
self.creation_timestamp = metadata["creation_timestamp"]
self.device_name = metadata["device_name"]
self.device_capability = metadata["device_capability"]
cache = {}
cache_data = serializable_cache["cache_data"]
for key_str, value in cache_data.items():
# Reconstruct the tuple key safely

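Taken together, the tune-and-reload flow against the merged cache file looks roughly like the sketch below, modeled on the updated tests later in this diff (the path, op name, runners, and inputs are illustrative):

    cache_path = os.path.join(temp_dir.name, "autotuner_cache.json")  # one file shared by all ranks
    with autotune(cache_path=cache_path):
        tuner = AutoTuner.get()
        runner, tactic = tuner.choose_one("my_custom_op", runners, tuning_config, [x, w])
    # later, restore a specific rank's entries from the merged file
    AutoTuner.get().profiling_cache.clear()
    AutoTuner.get().profiling_cache.load_cache(cache_path, rank=0)
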
View File

@@ -15,6 +15,7 @@ l0_dgx_b200:
backend: pytorch
orchestrator: mpi
tests:
- unittest/_torch/misc/test_autotuner.py::test_autotuner_distributed_strategy
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_nvfp4[enable_configurable_moe-disable_finalize_fusion-TRTLLM-dtype1]

View File

@@ -17,8 +17,10 @@ from tensorrt_llm._torch.autotuner import (AutoTuner, DistributedTuningStrategy,
FakeTensor, OptimizationProfile,
StaticDim, TunableRunner,
TuningConfig, autotune)
from tensorrt_llm._torch.distributed.communicator import MPIDist, TorchDist
from tensorrt_llm._torch.utils import (get_power_of_2_num_tokens_buckets,
next_positive_power_of_2)
from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm.bindings.internal.runtime import delay_kernel
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
@@ -323,8 +325,9 @@ def test_multiple_dynamic_shapes_cache():
# Do tuning with a sample input
x = torch.randn(3, 64)
temp_dir = tempfile.TemporaryDirectory()
with autotune(cache_path=os.path.join(temp_dir.name,
"test_multiple_dynamic_shapes.json")):
cache_path = os.path.join(temp_dir.name,
"test_multiple_dynamic_shapes.json")
with autotune(cache_path=cache_path):
tuner = AutoTuner.get()
runner, tactic = tuner.choose_one("test_multiple_dynamic_shapes",
runners, tuning_config, [x, w])
@@ -336,8 +339,7 @@ def test_multiple_dynamic_shapes_cache():
# Verify cache size - should have 12 entries (3x4 combinations)
# We also test the cache serialization and deserialization here.
AutoTuner.get().profiling_cache.clear()
AutoTuner.get().profiling_cache.load_cache(
os.path.join(temp_dir.name, "test_multiple_dynamic_shapes.rank0.json"))
AutoTuner.get().profiling_cache.load_cache(cache_path, rank=0)
cache_entries = tuner.profiling_cache.get_specific_custom_op(
"test_multiple_dynamic_shapes")
@@ -427,8 +429,9 @@ def test_autotuner_tuning_configs():
use_cuda_graph=False,
)
temp_dir = tempfile.TemporaryDirectory()
with autotune(cache_path=os.path.join(
temp_dir.name, "test_autotuner_tactic_configs.json")):
cache_path = os.path.join(temp_dir.name,
"test_autotuner_tactic_configs.json")
with autotune(cache_path=cache_path):
tuner = AutoTuner.get()
runner, best_tactic = tuner.choose_one("test_autotuner_tactic_configs",
runners, tuning_config, [x, w])
@@ -437,8 +440,7 @@ def test_autotuner_tuning_configs():
# Test if the tactic can be loaded from cache correctly
AutoTuner.get().profiling_cache.clear()
AutoTuner.get().profiling_cache.load_cache(
os.path.join(temp_dir.name, "test_autotuner_tactic_configs.rank0.json"))
AutoTuner.get().profiling_cache.load_cache(cache_path, rank=0)
# No further tuning should be performed.
runner, deserialized_tactic = tuner.choose_one(
@@ -646,9 +648,14 @@ def _distributed_worker_function(world_size, strategy):
rank=rank,
tp_size=world_size,
pp_size=1)
if mpi_disabled():
dist = TorchDist(mapping=mapping)
else:
dist = MPIDist(mapping=mapping)
tuner = AutoTuner.get()
tuner.clear_cache()
tuner.setup_distributed_state(mapping)
tuner.setup_distributed_state(mapping, dist)
x = torch.randn(16, 32, device='cuda')
w = torch.randn(32, 64, device='cuda')
@@ -663,12 +670,28 @@ def _distributed_worker_function(world_size, strategy):
runner = DistributedGemmRunner(prefer_tactics=prefer_tactics)
config = TuningConfig(distributed_tuning_strategy=strategy)
cache_path = os.environ.get("TLLM_AUTOTUNER_CACHE_PATH", None)
with autotune(tune_mode=True, cache_path=cache_path):
if rank == 0:
temp_dir = tempfile.TemporaryDirectory()
# rank 0 should broadcast the cache path to all ranks
cache_path = os.path.join(temp_dir.name, "test_distributed_tuning.json")
dist.broadcast(cache_path, root=0)
else:
cache_path = dist.broadcast(None, root=0)
with autotune(cache_path=cache_path):
tuner.choose_one(custom_op=f"test_distributed_{strategy}",
runners=[runner],
tuning_config=config,
inputs=inputs)
# Check only one file is created in the cache path
assert len(os.listdir(os.path.dirname(
cache_path))) == 1, "All ranks should write to a single merged cache file"
# Check cache for distributed tuning
AutoTuner.get().profiling_cache.clear()
AutoTuner.get().profiling_cache.load_cache(cache_path, rank)
selected_runner, best_tactic = tuner.choose_one(
custom_op=f"test_distributed_{strategy}",
runners=[runner],
@@ -706,8 +729,7 @@ def _distributed_worker_function(world_size, strategy):
],
)
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_distributed_broadcast_strategy(strategy, mpi_pool_executor):
"""Test broadcast strategy with real MPI processes."""
def test_autotuner_distributed_strategy(strategy, mpi_pool_executor):
world_size = 2
# Use MPIPoolExecutor to run distributed test
results = mpi_pool_executor.map(