[TRTLLM-10185][feat] AutoTuner Cache: Support cache file lock and merge all ranks into one (#10336)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

Parent: 5a8bfcbb50
Commit: 0937df2c68
@@ -2,6 +2,7 @@ import ast
 import contextlib
 import copy
 import enum
+import fcntl
 import inspect
 import itertools
 import json
@@ -266,15 +267,11 @@ def autotune(tune_mode: bool = True, cache_path: str = None):
     tune_required = tune_mode
     if cache_path is not None:
-        # check if the rank-specific file exists
-        cache_path_no_ext = os.path.splitext(cache_path)[0]
-        cache_path_no_ext_rank = cache_path_no_ext + f".rank{rank}.json"
-        # if the rank-specific file exists, load it
-        file_exists = os.path.exists(cache_path_no_ext_rank)
+        # if the rank-specific file exists, do not enable tuning mode
+        file_exists = os.path.exists(cache_path)
         if file_exists:
-            logger.info(
-                f"[Autotuner] Loading cache from {cache_path_no_ext_rank}")
-            autotuner.profiling_cache.load_cache(cache_path_no_ext_rank)
+            logger.info(f"[Autotuner] Loading cache from {cache_path}")
+            autotuner.profiling_cache.load_cache(cache_path, rank)

     # record the old tuning mode
     old_mode = autotuner.is_tuning_mode
@@ -293,8 +290,8 @@ def autotune(tune_mode: bool = True, cache_path: str = None):

     # save cache
     if cache_path is not None:
-        logger.info(f"[Autotuner] Saving cache to {cache_path_no_ext_rank}")
-        autotuner.profiling_cache.save_cache(cache_path_no_ext_rank)
+        logger.info(f"[Autotuner] Saving cache to {cache_path}")
+        autotuner.profiling_cache.save_cache(cache_path, rank)


 @dataclass
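For readers skimming the diff: after this change a single cache_path is shared by all ranks, the tuning pass appends each rank's results to that one file, and a later run that finds the file loads it and skips tuning. A minimal usage sketch follows; `runners`, `tuning_config`, and `inputs` are assumed to be built the same way the unit tests below build them, and the op name is a placeholder.

import os
import tempfile

from tensorrt_llm._torch.autotuner import AutoTuner, autotune

# Assumed to exist: runners, tuning_config, inputs (constructed as in the tests).
temp_dir = tempfile.TemporaryDirectory()
cache_path = os.path.join(temp_dir.name, "autotuner_cache.json")

# First pass: tune and append this rank's section to the shared cache file.
with autotune(cache_path=cache_path):
    tuner = AutoTuner.get()
    runner, tactic = tuner.choose_one("my_custom_op", runners, tuning_config, inputs)

# Later pass (same or different process): reload this rank's entries from the
# single merged file; no per-rank ".rank{N}.json" file is involved any more.
AutoTuner.get().profiling_cache.clear()
AutoTuner.get().profiling_cache.load_cache(cache_path, rank=0)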
@@ -439,7 +436,7 @@ class AutoTunerProfilingCache:
     def get_specific_custom_op(self, custom_op: str) -> Dict[Tuple, Tuple]:
         return {k: v for k, v in self.cache.items() if k[0] == custom_op}

-    def save_cache(self, file_path: Union[str, Path]) -> None:
+    def save_cache(self, file_path: Union[str, Path], rank: int) -> None:
         """Save the profiling cache to disk in JSON format.

         Args:
@@ -456,9 +453,21 @@ class AutoTunerProfilingCache:
         file_path.parent.mkdir(parents=True, exist_ok=True)

         try:
-            serializable_cache = self._serialize_cache_to_json()
-            with open(file_path, 'w') as f:
-                json.dump(serializable_cache, f, indent=2, default=str)
+            serialized_rank_cache_data = self._serialize_cache_data()
+            with open(file_path, 'a+') as f:
+                fcntl.flock(f, fcntl.LOCK_EX)
+                f.seek(0)
+                content = f.read()
+                if content.strip():
+                    current_cache = json.loads(content)
+                else:
+                    current_cache = {
+                        "metadata": self._serialize_metadata(),
+                    }
+                f.seek(0)
+                f.truncate()
+                current_cache[f"rank_{rank}"] = serialized_rank_cache_data
+                json.dump(current_cache, f, indent=2, default=str)
             logger.info(
                 f"[AutoTuner] Successfully saved cache to {file_path} using JSON format"
             )
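The save path opens the file in 'a+' mode so it is created if missing but never truncated before the exclusive flock is held; only after the lock is acquired does it read whatever another rank already wrote, merge in its own section, and rewrite the file. A self-contained sketch of that read-merge-rewrite pattern, with a generic payload and a hypothetical path (Unix-only, since it relies on fcntl):

import fcntl
import json


def merge_section_into_file(path: str, section_key: str, section_data: dict) -> None:
    """Merge one section into a shared JSON file under an exclusive advisory lock."""
    # 'a+' creates the file if it does not exist and never truncates it,
    # so no data is lost before the lock is actually held.
    with open(path, 'a+') as f:
        fcntl.flock(f, fcntl.LOCK_EX)  # blocks until no other writer holds the lock
        f.seek(0)
        content = f.read()
        merged = json.loads(content) if content.strip() else {}
        merged[section_key] = section_data
        f.seek(0)
        f.truncate()  # safe now: we are the only writer
        json.dump(merged, f, indent=2)
        # the lock is released when the file is closed


# Two "ranks" writing to the same file end up with both sections present.
merge_section_into_file("/tmp/merged_cache.json", "rank_0", {"op_a": 1})
merge_section_into_file("/tmp/merged_cache.json", "rank_1", {"op_b": 2})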
@@ -466,7 +475,7 @@ class AutoTunerProfilingCache:
             logger.error(f"[AutoTuner] Failed to save cache with JSON: {e}")
             raise

-    def load_cache(self, file_path: Union[str, Path]) -> None:
+    def load_cache(self, file_path: Union[str, Path], rank: int) -> None:
         """Load the profiling cache from disk in JSON format.

         Args:
@@ -486,8 +495,12 @@ class AutoTunerProfilingCache:

         try:
             with open(file_path, 'r') as f:
-                serializable_cache = json.load(f)
-            self.cache = self._deserialize_cache_from_json(serializable_cache)
+                fcntl.flock(f, fcntl.LOCK_SH)
+                current_cache_contents = json.load(f)
+            self._deserialize_metadata(current_cache_contents["metadata"])
+            assert f"rank_{rank}" in current_cache_contents, f"Rank {rank} cache not found in {file_path}"
+            self.cache = self._deserialize_cache_data(
+                current_cache_contents[f'rank_{rank}'])
             logger.info(
                 f"[AutoTuner] Successfully loaded cache from {file_path} using JSON format"
             )
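Taken together, save_cache and load_cache give the merged file one metadata block plus one rank_{N} section per rank that has saved, and each rank deserializes only its own section. An illustrative sketch of the resulting layout, written as a Python literal; field values and the op key are hypothetical, and per-entry fields beyond min_time are omitted here.

# Shape of the merged cache file after two ranks have called save_cache().
merged_cache_example = {
    "metadata": {
        "lib_version": "x.y.z",
        "creation_timestamp": "2025-01-01T00:00:00",
        "device_name": "NVIDIA ...",
        "device_capability": "...",
    },
    "rank_0": {
        # stringified tuple key -> serialized profiling result
        "('my_custom_op', ...)": {
            "runner_id": 0,
            "tactic": "tactic-as-string",
            "min_time": 0.012,
        },
    },
    "rank_1": {
        # same structure, written independently by rank 1 under the file lock
    },
}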
@@ -495,7 +508,21 @@ class AutoTunerProfilingCache:
             logger.error(f"[AutoTuner] Failed to load cache with JSON: {e}")
             raise

-    def _serialize_cache_to_json(self) -> Dict[str, Any]:
+    def _serialize_metadata(self) -> Dict[str, Any]:
+        return {
+            "lib_version": self.lib_version,
+            "creation_timestamp": self.creation_timestamp,
+            "device_name": self.device_name,
+            "device_capability": self.device_capability,
+        }
+
+    def _deserialize_metadata(self, metadata: Dict[str, Any]) -> None:
+        self.lib_version = metadata["lib_version"]
+        self.creation_timestamp = metadata["creation_timestamp"]
+        self.device_name = metadata["device_name"]
+        self.device_capability = metadata["device_capability"]
+
+    def _serialize_cache_data(self) -> Dict[str, Any]:
         """Convert the profiling cache to a JSON-serializable format.

         Returns:
@@ -505,15 +532,7 @@ class AutoTunerProfilingCache:
         This method handles the conversion of complex objects to JSON-compatible
         representations. Some type information may be lost in the conversion.
         """
-        serializable_cache = {
-            "metadata": {
-                "lib_version": self.lib_version,
-                "creation_timestamp": self.creation_timestamp,
-                "device_name": self.device_name,
-                "device_capability": self.device_capability,
-            },
-            "cache_data": {},
-        }
+        serializable_cache = {}

         for key, value in self.cache.items():
             # Convert any simple object to string for JSON compatibility
@@ -529,7 +548,7 @@ class AutoTunerProfilingCache:
                     f"[AutoTuner] Could not serialize tactic: {tactic_str} for cache key {key_str} due to {e}. Deserialization may fail.",
                    key=tactic_str)

-            serializable_cache["cache_data"][key_str] = {
+            serializable_cache[key_str] = {
                 "runner_id": runner_id,
                 "tactic": tactic_str,
                 "min_time": min_time,
@@ -537,8 +556,8 @@ class AutoTunerProfilingCache:

         return serializable_cache

-    def _deserialize_cache_from_json(
-            self, serializable_cache: Dict[str, Any]) -> Dict[Tuple, Tuple]:
+    def _deserialize_cache_data(
+            self, cache_data: Dict[str, Any]) -> Dict[Tuple, Tuple]:
         """Convert JSON-serialized cache back to the original format.

         Args:
@@ -551,14 +570,7 @@ class AutoTunerProfilingCache:
         This attempts to reconstruct the original data structures but may not
         perfectly preserve all type information, especially for complex tactic objects.
         """
-        metadata = serializable_cache["metadata"]
-        self.lib_version = metadata["lib_version"]
-        self.creation_timestamp = metadata["creation_timestamp"]
-        self.device_name = metadata["device_name"]
-        self.device_capability = metadata["device_capability"]
-
         cache = {}
-        cache_data = serializable_cache["cache_data"]

         for key_str, value in cache_data.items():
             # Reconstruct the tuple key safely
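The rename above keeps the existing key handling: cache keys are tuples whose first element is the custom-op name (see get_specific_custom_op), stored as strings in JSON and parsed back on load, which is what the "Reconstruct the tuple key safely" comment and the ast import at the top of the file refer to. A minimal round-trip sketch of that idea, assuming the key is built only from literals that ast.literal_eval can parse:

import ast
import json

# In-memory key: (custom_op, input-shape profile, ...); contents are illustrative.
key = ("my_custom_op", (16, 32), (32, 64))

# Serialize: JSON object keys must be strings, so the tuple is stringified.
key_str = str(key)
blob = json.dumps({key_str: {"runner_id": 0, "tactic": "tactic-0", "min_time": 0.01}})

# Deserialize: ast.literal_eval parses the string back into a tuple without
# executing arbitrary code, unlike eval(), hence "safely".
restored = {ast.literal_eval(k): v for k, v in json.loads(blob).items()}
assert ("my_custom_op", (16, 32), (32, 64)) in restored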
@@ -15,6 +15,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
+  - unittest/_torch/misc/test_autotuner.py::test_autotuner_distributed_strategy
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_nvfp4[enable_configurable_moe-disable_finalize_fusion-TRTLLM-dtype1]
@@ -17,8 +17,10 @@ from tensorrt_llm._torch.autotuner import (AutoTuner, DistributedTuningStrategy,
                                            FakeTensor, OptimizationProfile,
                                            StaticDim, TunableRunner,
                                            TuningConfig, autotune)
+from tensorrt_llm._torch.distributed.communicator import MPIDist, TorchDist
 from tensorrt_llm._torch.utils import (get_power_of_2_num_tokens_buckets,
                                        next_positive_power_of_2)
+from tensorrt_llm._utils import mpi_disabled
 from tensorrt_llm.bindings.internal.runtime import delay_kernel
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
@@ -323,8 +325,9 @@ def test_multiple_dynamic_shapes_cache():
     # Do tuning with a sample input
     x = torch.randn(3, 64)
     temp_dir = tempfile.TemporaryDirectory()
-    with autotune(cache_path=os.path.join(temp_dir.name,
-                                          "test_multiple_dynamic_shapes.json")):
+    cache_path = os.path.join(temp_dir.name,
+                              "test_multiple_dynamic_shapes.json")
+    with autotune(cache_path=cache_path):
         tuner = AutoTuner.get()
         runner, tactic = tuner.choose_one("test_multiple_dynamic_shapes",
                                           runners, tuning_config, [x, w])
@@ -336,8 +339,7 @@ def test_multiple_dynamic_shapes_cache():
     # Verify cache size - should have 12 entries (3x4 combinations)
     # We also test the cache serialization and deserialization here.
     AutoTuner.get().profiling_cache.clear()
-    AutoTuner.get().profiling_cache.load_cache(
-        os.path.join(temp_dir.name, "test_multiple_dynamic_shapes.rank0.json"))
+    AutoTuner.get().profiling_cache.load_cache(cache_path, rank=0)
     cache_entries = tuner.profiling_cache.get_specific_custom_op(
         "test_multiple_dynamic_shapes")
@@ -427,8 +429,9 @@ def test_autotuner_tuning_configs():
         use_cuda_graph=False,
     )
     temp_dir = tempfile.TemporaryDirectory()
-    with autotune(cache_path=os.path.join(
-            temp_dir.name, "test_autotuner_tactic_configs.json")):
+    cache_path = os.path.join(temp_dir.name,
+                              "test_autotuner_tactic_configs.json")
+    with autotune(cache_path=cache_path):
         tuner = AutoTuner.get()
         runner, best_tactic = tuner.choose_one("test_autotuner_tactic_configs",
                                                runners, tuning_config, [x, w])
@@ -437,8 +440,7 @@ def test_autotuner_tuning_configs():

     # Test if the tactic can be loaded from cache correctly
     AutoTuner.get().profiling_cache.clear()
-    AutoTuner.get().profiling_cache.load_cache(
-        os.path.join(temp_dir.name, "test_autotuner_tactic_configs.rank0.json"))
+    AutoTuner.get().profiling_cache.load_cache(cache_path, rank=0)

     # No further tuning should be performed.
     runner, deserialized_tactic = tuner.choose_one(
@@ -646,9 +648,14 @@ def _distributed_worker_function(world_size, strategy):
                       rank=rank,
                       tp_size=world_size,
                       pp_size=1)
+    if mpi_disabled():
+        dist = TorchDist(mapping=mapping)
+    else:
+        dist = MPIDist(mapping=mapping)
+
     tuner = AutoTuner.get()
     tuner.clear_cache()
-    tuner.setup_distributed_state(mapping)
+    tuner.setup_distributed_state(mapping, dist)

     x = torch.randn(16, 32, device='cuda')
     w = torch.randn(32, 64, device='cuda')
@@ -663,12 +670,28 @@ def _distributed_worker_function(world_size, strategy):
     runner = DistributedGemmRunner(prefer_tactics=prefer_tactics)
     config = TuningConfig(distributed_tuning_strategy=strategy)

-    cache_path = os.environ.get("TLLM_AUTOTUNER_CACHE_PATH", None)
-    with autotune(tune_mode=True, cache_path=cache_path):
+    if rank == 0:
+        temp_dir = tempfile.TemporaryDirectory()
+        # rank 0 should broadcast the cache path to all ranks
+        cache_path = os.path.join(temp_dir.name, "test_distributed_tuning.json")
+        dist.broadcast(cache_path, root=0)
+    else:
+        cache_path = dist.broadcast(None, root=0)
+
+    with autotune(cache_path=cache_path):
         tuner.choose_one(custom_op=f"test_distributed_{strategy}",
                          runners=[runner],
                          tuning_config=config,
                          inputs=inputs)
+
+    # Check only one file is created in the cache path
+    assert len(os.listdir(os.path.dirname(
+        cache_path))) == 1, "Only one rank file should be created"
+
+    # Check cache for distributed tuning
+    AutoTuner.get().profiling_cache.clear()
+    AutoTuner.get().profiling_cache.load_cache(cache_path, rank)
+
     selected_runner, best_tactic = tuner.choose_one(
         custom_op=f"test_distributed_{strategy}",
         runners=[runner],
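The test relies on the Dist object's object-broadcast semantics: rank 0 passes the freshly created path, every other rank receives it as the return value, so all ranks append to one file and the os.listdir assertion can demand exactly one cache file. The same pattern can be reproduced outside the test harness; a minimal standalone sketch using mpi4py, which is an assumption of this example only (the test itself uses MPIDist/TorchDist):

import os
import tempfile

from mpi4py import MPI  # standalone illustration; not what the test imports

comm = MPI.COMM_WORLD
if comm.Get_rank() == 0:
    cache_dir = tempfile.mkdtemp()
    cache_path = os.path.join(cache_dir, "test_distributed_tuning.json")
else:
    cache_path = None

# After the broadcast every rank holds the same path, so each rank's
# save_cache() call appends its rank_{N} section to one shared file.
cache_path = comm.bcast(cache_path, root=0)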
@@ -706,8 +729,7 @@ def _distributed_worker_function(world_size, strategy):
     ],
 )
 @pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
-def test_distributed_broadcast_strategy(strategy, mpi_pool_executor):
-    """Test broadcast strategy with real MPI processes."""
+def test_autotuner_distributed_strategy(strategy, mpi_pool_executor):
     world_size = 2
     # Use MPIPoolExecutor to run distributed test
     results = mpi_pool_executor.map(