[None][fix] Fix on-disk cache and revise logger/statistics for AutoTuner. (#9211)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Yukun He 2025-11-28 13:32:21 +08:00 committed by GitHub
parent c87e81c1d8
commit 60c43a200a
5 changed files with 163 additions and 142 deletions


@@ -98,6 +98,7 @@ class TuningConfig:
         use_cold_l2_cache (bool): Whether to use cold L2 cache.
             This flag is to create circular buffer of input tensors to avoid L2 cache hits to simulate cold L2 cache.
             Notice that not all tuning processes can benefit from this feature.
+        use_cuda_graph (bool): Whether to use CUDA graph for the tuning process.
     """
     dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...] = ()
     constraint_specs: Tuple[ConstraintSpec, ...] = ()
@@ -211,8 +212,16 @@ class TunableRunner(ABC):
         """
         raise NotImplementedError

-    def __hash__(self):
-        return hash(tuple(self.__dict__.values()))
+    def unique_id(self):
+        """
+        Returns a tuple of the unique id of the runner. The unique id will be converted to a string for the cache key.
+        A common practice is to return a tuple of the runner's attributes, for example:
+            return (self.output_dtype, self.attribute_1, ...)
+
+        Returns:
+            Any: The unique id of the runner, which can be converted to a string for the cache key.
+        """
+        return tuple(self.__dict__.values())


 @contextlib.contextmanager
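For reference, a minimal sketch of how a downstream runner might satisfy this contract. The class, attributes, and import path here are illustrative assumptions, not part of this commit; get_valid_tactics()/forward() are omitted for brevity:

    import torch

    from tensorrt_llm._torch.autotuner import TunableRunner  # import path may differ by version


    class MyGemmRunner(TunableRunner):
        """Hypothetical runner: only attributes that affect tactic validity go into unique_id()."""

        def __init__(self, output_dtype: torch.dtype, to_userbuffers: bool, alpha: float):
            self.output_dtype = output_dtype
            self.to_userbuffers = to_userbuffers
            self.alpha = alpha  # runtime scalar; does not change which tactics are valid

        def unique_id(self):
            # Omit self.alpha so two instances differing only in alpha share cache entries.
            return (self.output_dtype, self.to_userbuffers)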
@@ -226,7 +235,6 @@ def autotune(tune_mode: bool = True, cache_path: str = None, rank: int = 0):
         # if the rank-specific file exists, load it
         file_exists = os.path.exists(cache_path_no_ext_rank)
         # if the rank-specific file exists, do not enable tuning mode
-        tune_required = tune_required and not os.path.exists(cache_path)
         if file_exists:
             logger.info(
                 f"[Autotuner] Loading cache from {cache_path_no_ext_rank}")
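A hedged usage sketch of the on-disk cache handling above: each rank works with its own file derived from cache_path (the exact naming is handled internally), and if that per-rank file already exists it is loaded instead of re-tuning. The op being called is hypothetical:

    import torch

    from tensorrt_llm._torch.autotuner import autotune  # import path may differ by version

    # First run: tunes and persists the cache. Subsequent runs with the same
    # cache_path/rank load the rank-specific file and skip tuning.
    with autotune(cache_path="/tmp/autotuner_cache.json", rank=0):
        my_tunable_op(torch.randn(128, 64, device="cuda"))  # hypothetical tunable custom op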
@@ -259,8 +267,8 @@ class AutoTunerStatistics:
         cache_misses (int): Number of cache misses requiring fallback
         cache_miss_config_collection (Dict[str, Set[OptimizationProfile]]): Collection of configs that caused cache misses
         failed_profiling_count (Dict[str, int]): Number of failed profiling attempts per operation
-        tuned_op_total_configs (Dict[str, int]): Total configurations tried per operation
-        tuned_op_successful_configs (Dict[str, int]): Successful configurations per operation
+        tuned_op_profiled_configs (Dict[str, int]): Profiled configurations per operation
+        tuned_op_time_cost (Dict[str, float]): Time cost per operation
     """
     cache_misses: int = 0
     cache_miss_config_collection: Dict[str,
@@ -268,8 +276,8 @@ class AutoTunerStatistics:
     failed_profiling_count: Dict[str, Set[Tuple[str, TunableRunner,
                                                 OptimizationProfile]]] = field(
                                                     default_factory=dict)
-    tuned_op_total_configs: Dict[str, int] = field(default_factory=dict)
-    tuned_op_successful_configs: Dict[str, int] = field(default_factory=dict)
+    tuned_op_profiled_configs: Dict[str, int] = field(default_factory=dict)
+    tuned_op_time_cost: Dict[str, float] = field(default_factory=dict)

     def __str__(self) -> str:
         """Return a string representation of collected statistics.
@@ -284,22 +292,23 @@ class AutoTunerStatistics:
             for profile in sorted(profiles, key=str):
                 stats_str += f" - Config: {profile}\n"

-        if self.tuned_op_total_configs:
+        if self.tuned_op_profiled_configs:
             stats_str += "Tuned operations:\n"
-            for op in sorted(self.tuned_op_total_configs.keys()):
-                total = self.tuned_op_total_configs[op]
-                successful = self.tuned_op_successful_configs.get(op, 0)
-                failed = len(self.failed_profiling_count.get(op, set()))
-                success_rate = (successful / total * 100) if total > 0 else 0
+            for op in sorted(self.tuned_op_profiled_configs.keys()):
+                successful = self.tuned_op_profiled_configs[op]
+                failed = len(self.failed_profiling_count[op])
                 stats_str += f" {op}:\n"
-                stats_str += f" - Total configs tried: {total}\n"
                 stats_str += f" - Successful configs: {successful}\n"
                 stats_str += f" - Failed profiling count: {failed}\n"
                 if failed > 0:
                     stats_str += f" - Failed profiling combinations:\n"
                     for failed_key in self.failed_profiling_count[op]:
                         stats_str += f" - {failed_key}\n"
-                stats_str += f" - Success rate: {success_rate:.1f}%\n"
+
+        if self.tuned_op_time_cost:
+            stats_str += "Tuned operations time cost:\n"
+            for op in sorted(self.tuned_op_time_cost.keys()):
+                stats_str += f" {op}: {self.tuned_op_time_cost[op] * 1000:.4f} milliseconds\n"

         return stats_str
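To make the revised accounting concrete, a small illustrative sketch (standalone, made-up values rather than a real run) of how the new fields feed the summary:

    from tensorrt_llm._torch.autotuner import AutoTunerStatistics  # import path may differ by version

    stats = AutoTunerStatistics()
    stats.tuned_op_profiled_configs["my_custom_op"] = 12   # profiles successfully tuned and cached
    stats.failed_profiling_count["my_custom_op"] = set()   # no failed (runner, profile) combinations
    stats.tuned_op_time_cost["my_custom_op"] = 0.25        # seconds spent tuning, printed in milliseconds
    print(stats)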
@@ -374,7 +383,7 @@ class AutoTunerProfilingCache:
         return (
             custom_op,
             runner.__class__.__name__,
-            hash(runner),
+            str(runner.unique_id()),
             AutoTuner.get()._find_nearest_profile(
                 input_shapes,
                 tuning_config.dynamic_tensor_specs,
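Because str(runner.unique_id()) replaces hash(runner), the key becomes a deterministic, serializable tuple that survives a round trip through the on-disk cache. Roughly, a key now looks like the following sketch (all values made up, and the profile component shown only schematically):

    # Illustrative shape of a cache key after this change.
    cache_key = (
        "my_custom_op",                 # custom_op name
        "FP8RowwiseGemmRunner",         # runner.__class__.__name__
        "(False, torch.bfloat16)",      # str(runner.unique_id())
        "nearest-optimization-profile", # placeholder for the nearest profile entry
    )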
@@ -546,6 +555,11 @@ class AutoTuner:
         # Last captured choose_one() contexts
         self._last_capture: Optional['AutoTuner.TacticsCapture'] = None

+        # Increase log level for AutoTuner associated logger
+        self._log_level_to_info = os.getenv(
+            "TLLM_AUTOTUNER_LOG_LEVEL_DEBUG_TO_INFO", '0') == '1'
+        self._debug_logger = logger.info if self._log_level_to_info else logger.debug
+
     @classmethod
     def get(cls):
         if cls._instance is None:
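The pattern above binds the logging function once at construction time. A self-contained sketch of the same idea using only the standard logging module (the logger name below is arbitrary):

    import logging
    import os

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("autotuner_demo")

    # Promote AutoTuner debug output to INFO when the env var is set to '1'.
    log_to_info = os.getenv("TLLM_AUTOTUNER_LOG_LEVEL_DEBUG_TO_INFO", "0") == "1"
    debug_logger = logger.info if log_to_info else logger.debug

    debug_logger("[Autotuner] visible at INFO level only when the env var is set")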
@@ -726,11 +740,14 @@ class AutoTuner:
         assert all([isinstance(r, TunableRunner) for r in runners]), \
             "All Given runners must be subclass of TunableRunner"

+        tuning_start_time = time.perf_counter()
         profiles = self._optimization_profiles(tuning_config, inputs)

-        # Record the total configs to try
-        self.stats.tuned_op_total_configs[custom_op] = len(profiles)
+        # Initialize the statistics for the custom_op
+        if custom_op not in self.stats.tuned_op_profiled_configs:
+            self.stats.tuned_op_profiled_configs[custom_op] = 0
+        if custom_op not in self.stats.failed_profiling_count:
+            self.stats.failed_profiling_count[custom_op] = set()

         new_tuning_failure_occured = False
         for p in profiles:
@@ -746,16 +763,15 @@ class AutoTuner:
                     cache_key = self.profiling_cache.get_cache_key(
                         custom_op, runners[best_runner_id], p.get_opt_shapes(),
                         tuning_config)
+                    self._debug_logger(
+                        f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
+                    )
                     # inspect call stack
                     self.profiling_cache[cache_key] = (best_runner_id,
                                                        best_tactic, min_time)
-                    self.stats.tuned_op_successful_configs[
-                        custom_op] = self.stats.tuned_op_successful_configs.get(
-                            custom_op, 0) + 1
-                    logger.debug(
-                        f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
-                    )
+                    self.stats.tuned_op_profiled_configs[custom_op] += 1
                 else:
                     logger.warning_once(
                         f"[Autotuner] No valid runner/tactic was found for custom_op={custom_op}, input_shapes={input_shapes}. "
@@ -782,6 +798,10 @@ class AutoTuner:
         _, runner_id, tactic, _ = self.profiling_cache.search_cache(
             custom_op, runners, input_shapes, tuning_config)

+        tuning_end_time = time.perf_counter()
+        self.stats.tuned_op_time_cost[
+            custom_op] = self.stats.tuned_op_time_cost.get(
+                custom_op, 0) + tuning_end_time - tuning_start_time
         return (runners[runner_id], tactic)

     def _profile_runners(
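The time-cost bookkeeping is a plain accumulate-by-key pattern around time.perf_counter(). A self-contained sketch of the same idea (dictionary and op name are hypothetical, wrapped here in a context manager for reuse):

    import contextlib
    import time

    tuned_op_time_cost: dict[str, float] = {}

    @contextlib.contextmanager
    def record_tuning_time(custom_op: str):
        # Accumulate wall-clock seconds per op across repeated calls.
        start = time.perf_counter()
        try:
            yield
        finally:
            elapsed = time.perf_counter() - start
            tuned_op_time_cost[custom_op] = tuned_op_time_cost.get(custom_op, 0.0) + elapsed

    with record_tuning_time("my_custom_op"):
        pass  # profiling work would go here
    print(tuned_op_time_cost)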
@@ -832,14 +852,13 @@ class AutoTuner:
                         f"[Autotuner] Failed when profiling runner={runner}, tactic={tac}, shapes={shapes}. Set TLLM_LOG_LEVEL=DEBUG for more details.",
                         key=(custom_op, "warning_autotuning_profile_failure"),
                     )
-                    logger.debug_once(
-                        f"[Autotuner] Exception captured: {e}",
-                        key=(custom_op, "debug_autotuning_exception"),
-                    )
+                    (logger.info_once
+                     if self._log_level_to_info else logger.debug_once)(
+                         f"[Autotuner] Exception captured: {e}",
+                         key=(custom_op, "debug_autotuning_exception"),
+                     )

                     # Record the failed profiling combinations
-                    if custom_op not in self.stats.failed_profiling_count:
-                        self.stats.failed_profiling_count[custom_op] = set()
                     self.stats.failed_profiling_count[custom_op].add(
                         self.profiling_cache.get_cache_key(
                             custom_op, runner, profile.get_opt_shapes(),
@@ -957,7 +976,7 @@ class AutoTuner:
         avg_time = pure_profile(stream, self.repeat)

         shapes = self._get_input_sizes(inputs)
-        logger.debug(
+        self._debug_logger(
             f"[Autotuner] Profiled runner={runner}, tactic={tactic}, shapes={shapes}: {avg_time:.6f}ms."
         )
@@ -1043,7 +1062,7 @@ class AutoTuner:
                 p.shapes[spec.input_idx][spec.dim_idx] = DynamicDim(
                     min_value, opt_value, max_value)
             generated_profiles.append(p)
-            logger.debug(f"[Autotuner] Generated profile: {p}")
+            self._debug_logger(f"[Autotuner] Generated profile: {p}")
         return generated_profiles

     @classmethod
@@ -1159,7 +1178,7 @@ class AutoTuner:
             input.element_size() if isinstance(input, torch.Tensor) else 0
             for input in inputs)
         if one_buffer_bytes <= 0:
-            logger.debug(
+            self._debug_logger(
                 "[Autotuner] No tensor inputs or zero-sized tensors; falling back to single-batch profiling."
             )
             return [inputs]
@@ -1174,7 +1193,7 @@ class AutoTuner:
                 list(t.clone() if isinstance(t, torch.Tensor) else t
                      for t in inputs))

-        logger.debug(
+        self._debug_logger(
             f"[Autotuner] use_cold_l2_cache={tuning_config.use_cold_l2_cache}, use {num_buffers} different tensors for profiling"
         )
         return inputs_list
@@ -1188,16 +1207,23 @@ class AutoTuner:
         self.stats = AutoTunerStatistics()

     def print_profiling_cache(self):
-        logger.debug(f"[Autotuner] The profiling_cache entries:")
-        logger.debug(
+        self._debug_logger(f"[Autotuner] The profiling_cache entries:")
+        self._debug_logger(
             f"[Autotuner] Cache contents: (custom_op, runner, hash(attributes), shape_profiles) -> (runner_id, tactic, shape_profile(ignored))"
         )
         for key, value in self.profiling_cache.cache.items():
             runner_id, tactic, min_time = value
-            logger.debug(
+            self._debug_logger(
                 f"[Autotuner] {key}: (runner_id={runner_id}, tactic={tactic}, min_time={min_time})"
             )
+        self.print_statistics()
+
+    def print_statistics(self):
+        self._debug_logger(f"[Autotuner] The statistics:")
+        for line in self.stats.__str__().split("\n"):
+            self._debug_logger(line)

     @contextlib.contextmanager
     def capture(self):
         """Context manager for capturing execution contexts for testing.
@@ -1271,7 +1297,7 @@ class AutoTuner:
             runner_idx = runners.index(runner)
             runner_tactic_list.append((runner_idx, tactic))

-        logger.debug(
+        self._debug_logger(
             f"[Autotuner][replay]: Testing configuration: {runner_tactic_list}")

         # Replay the contexts with given (runner, tactic) pairs


@@ -55,13 +55,8 @@ if IS_CUTLASS_DSL_AVAILABLE:
             )

         # rewrite the hash function because the value of self.alpha doesn't affect the tactic.
-        def __hash__(self):
-            return hash((self.output_dtype, ))
-
-        def __eq__(self, other):
-            if not isinstance(other, self.__class__):
-                return False
-            return self.output_dtype == other.output_dtype
+        def unique_id(self):
+            return (self.output_dtype, )

         def get_valid_tactics(
             self,


@@ -93,6 +93,29 @@ class MoERunner(TunableRunner):
                           profile: OptimizationProfile, **kwargs) -> List[int]:
         return range(self.fused_moe_runner.get_tactic_num(kwargs["gemm_idx"]))

+    def unique_id(self):
+        return (
+            self.x_dtype,
+            self.weight_dtype,
+            self.output_dtype,
+            self.top_k,
+            self.tp_size,
+            self.tp_rank,
+            self.ep_size,
+            self.ep_rank,
+            self.cluster_size,
+            self.cluster_rank,
+            self.enable_alltoall,
+            self.use_deepseek_fp8_block_scale,
+            self.use_w4_group_scaling,
+            self.use_int8_woq_per_channel,
+            self.use_mxfp8_act_scaling,
+            self.min_latency_mode,
+            self.use_fused_finalize,
+            self.activation_type,
+            self.unpadded_hidden_size,
+        )
+
     def forward(
         self,
         inputs: List[torch.Tensor],
@@ -316,6 +339,12 @@ class FP8RowwiseGemmRunner(TunableRunner):
         self.fp8_rowwise_gemm_runner = FP8RowwiseGemmRunner.runner_dict[
             instance_key]

+    def unique_id(self):
+        return (
+            self.to_userbuffers,
+            self.output_dtype,
+        )
+
     def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile, **kwargs) -> List[int]:
         return list(range(self.fp8_rowwise_gemm_runner.get_num_configs()))
@@ -398,6 +427,12 @@ class FP4GemmRunner(TunableRunner):
             output_dtype, int(fp4_gemm_type))
         self.fp4_gemm_runner = FP4GemmRunner.runner_dict[instance_key]

+    def unique_id(self):
+        return (
+            self.to_userbuffers,
+            self.output_dtype,
+        )
+
     def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile, **kwargs) -> List[int]:
         return list(range(self.fp4_gemm_runner.get_num_configs()))
@@ -447,6 +482,12 @@ class CublasLtFP4GemmRunner(TunableRunner):
         self.cublaslt_runner = CublasLtFP4GemmRunner.runner_dict[instance_key]

+    def unique_id(self):
+        return hash((
+            self.to_userbuffers,
+            self.output_dtype,
+        ))
+
     def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile, **kwargs) -> List[int]:
         """Get all valid tactics (algorithms) from cuBLASLt heuristic."""
@@ -592,6 +633,15 @@ class FP8BatchedGemmRunner(TunableRunner):
         self.kernel_runner = FP8BatchedGemmRunner.runner_dict[instance_key]

+    def unique_id(self):
+        return (
+            self.output_dtype,
+            self.use_deep_seek_fp8,
+            self.low_latency_kernel,
+            self.tile_size,
+            self.epilogue_tile_m,
+        )
+
     def forward(
         self,
         inputs: List[torch.Tensor],
@@ -827,6 +877,12 @@ class WeightOnlyQuantGemmRunner(TunableRunner):
         self.weight_only_quant_gemm_runner = WeightOnlyQuantGemmRunner.runner_dict[
             instance_key]

+    def unique_id(self):
+        return (
+            self.output_dtype,
+            self.to_userbuffers,
+        )
+
     def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile, **kwargs) -> List[int]:
         return list(range(self.weight_only_quant_gemm_runner.get_num_configs()))
@@ -894,6 +950,9 @@ class FinegrainedMixedDtypeGemm(TunableRunner):
     def __init__(self, activation_dtype: torch.dtype, output_dtype: torch.dtype,
                  quant_mode: int):
+        self.activation_dtype = activation_dtype
+        self.output_dtype = output_dtype
+        self.quant_mode = quant_mode
         instance_key = (activation_dtype, output_dtype, quant_mode)
         if instance_key not in FinegrainedMixedDtypeGemm._runner_dict:
             FinegrainedMixedDtypeGemm._runner_dict[
@@ -902,6 +961,13 @@ class FinegrainedMixedDtypeGemm(TunableRunner):
         self._finegrained_mixed_dtype_gemm_runner = FinegrainedMixedDtypeGemm._runner_dict[
             instance_key]

+    def unique_id(self):
+        return (
+            self.activation_dtype,
+            self.output_dtype,
+            self.quant_mode,
+        )
+
     def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile, **kwargs) -> List[int]:
         return list(
@@ -1012,6 +1078,12 @@ class fp8SwapABGemmRunner(TunableRunner):
         self.output_dtype = output_dtype
         self.disable_ue8m0_cast = disable_ue8m0_cast

+    def unique_id(self):
+        return (
+            self.output_dtype,
+            self.disable_ue8m0_cast,
+        )
+
     def get_valid_tactics(
         self,
         inputs: List[torch.Tensor],


@@ -191,24 +191,10 @@ class FP4BlockScaleMoERunner(TunableRunner):
         FP4BlockScaleMoERunner.tuning_config = FP4BlockScaleMoERunner.get_tuning_config(
         )

-    # The hash is used by the autotuner to get the cache key, so we hash on members
-    # that influence tactic validity here. e.g. we are tuning FC1 and FC2
-    # so the routing type does not matter
-    def __hash__(self):
-        return hash((
-            self.top_k,
-            self.intermediate_size,
-            self.local_num_experts,
-        ))
-
-    # __eq__ and __hash__ must agree
-    def __eq__(self, other):
-        if not isinstance(other, FP4BlockScaleMoERunner):
-            return False
-
-        return (self.top_k == other.top_k
-                and self.intermediate_size == other.intermediate_size
-                and self.local_num_experts == other.local_num_experts)
+    # The unique_id is used by the autotuner to get the cache key, so we hash on members
+    # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing type does not matter
+    def unique_id(self):
+        return (self.top_k, self.intermediate_size, self.local_num_experts)

     def get_runner(self):
         instance_key = ()
@@ -558,24 +544,11 @@ class FP8BlockScaleMoERunner(TunableRunner):
         FP8BlockScaleMoERunner.tuning_config = FP8BlockScaleMoERunner.get_tuning_config(
         )

-    # The hash is used by the autotuner to get the cache key, so we hash on members
+    # The unique_id is used by the autotuner to get the cache key, so we hash on members
     # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing
     # type does not matter
-    def __hash__(self):
-        return hash((
-            self.top_k,
-            self.intermediate_size,
-            self.local_num_experts,
-        ))
-
-    # __eq__ and __hash__ must agree
-    def __eq__(self, other):
-        if not isinstance(other, FP8BlockScaleMoERunner):
-            return False
-
-        return (self.top_k == other.top_k
-                and self.intermediate_size == other.intermediate_size
-                and self.local_num_experts == other.local_num_experts)
+    def unique_id(self):
+        return (self.top_k, self.intermediate_size, self.local_num_experts)

     def get_runner(self):
         instance_key = ()
@@ -845,30 +818,18 @@ class MxE4m3MxE2m1BlockScaleMoERunner(TunableRunner):
         MxE4m3MxE2m1BlockScaleMoERunner.tuning_config = MxE4m3MxE2m1BlockScaleMoERunner.get_tuning_config(
         )

-    # The hash is used by the autotuner to get the cache key, so we hash on members
+    # The unique_id is used by the autotuner to get the cache key, so we hash on members
     # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing
     # type does not matter
-    def __hash__(self):
-        return hash((
+    def unique_id(self):
+        return (
             self.top_k,
             self.intermediate_size,
             self.valid_hidden_size,
             self.valid_intermediate_size,
             self.local_num_experts,
             self.act_type,
-        ))
-
-    # __eq__ and __hash__ must agree
-    def __eq__(self, other):
-        if not isinstance(other, MxE4m3MxE2m1BlockScaleMoERunner):
-            return False
-
-        return (self.top_k == other.top_k
-                and self.intermediate_size == other.intermediate_size
-                and self.valid_hidden_size == other.valid_hidden_size and
-                self.valid_intermediate_size == other.valid_intermediate_size
-                and self.local_num_experts == other.local_num_experts
-                and self.act_type == other.act_type)
+        )

     def get_runner(self):
         instance_key = (self.act_type, True)
@@ -1145,30 +1106,18 @@ class E4m3MxE2m1BlockScaleMoERunner(TunableRunner):
         E4m3MxE2m1BlockScaleMoERunner.tuning_config = E4m3MxE2m1BlockScaleMoERunner.get_tuning_config(
         )

-    # The hash is used by the autotuner to get the cache key, so we hash on members
+    # The unique_id is used by the autotuner to get the cache key, so we hash on members
     # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing
     # type does not matter
-    def __hash__(self):
-        return hash((
+    def unique_id(self):
+        return (
             self.top_k,
             self.intermediate_size,
             self.valid_hidden_size,
             self.valid_intermediate_size,
             self.local_num_experts,
             self.act_type,
-        ))
-
-    # __eq__ and __hash__ must agree
-    def __eq__(self, other):
-        if not isinstance(other, E4m3MxE2m1BlockScaleMoERunner):
-            return False
-
-        return (self.top_k == other.top_k
-                and self.intermediate_size == other.intermediate_size
-                and self.valid_hidden_size == other.valid_hidden_size and
-                self.valid_intermediate_size == other.valid_intermediate_size
-                and self.local_num_experts == other.local_num_experts
-                and self.act_type == other.act_type)
+        )

     def get_runner(self):
         instance_key = (self.act_type, False)
@@ -1425,10 +1374,10 @@ class Bf16MxE2m1BlockScaleMoERunner(TunableRunner):
         Bf16MxE2m1BlockScaleMoERunner.tuning_config = Bf16MxE2m1BlockScaleMoERunner.get_tuning_config(
         )

-    # The hash is used by the autotuner to get the cache key, so we hash on members
+    # The unique_id is used by the autotuner to get the cache key, so we hash on members
     # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing
     # type does not matter
-    def __hash__(self):
+    def unique_id(self):
         return hash((
             self.top_k,
             self.intermediate_size,
@@ -1438,18 +1387,6 @@ class Bf16MxE2m1BlockScaleMoERunner(TunableRunner):
             self.act_type,
         ))

-    # __eq__ and __hash__ must agree
-    def __eq__(self, other):
-        if not isinstance(other, Bf16MxE2m1BlockScaleMoERunner):
-            return False
-
-        return (self.top_k == other.top_k
-                and self.intermediate_size == other.intermediate_size
-                and self.valid_hidden_size == other.valid_hidden_size and
-                self.valid_intermediate_size == other.valid_intermediate_size
-                and self.local_num_experts == other.local_num_experts
-                and self.act_type == other.act_type)
-
     def get_runner(self):
         instance_key = (self.act_type, )
         if instance_key not in Bf16MxE2m1BlockScaleMoERunner.runner_dict:
@@ -1695,26 +1632,13 @@ class FP8FP4BlockScaleMoERunner(TunableRunner):
         FP8FP4BlockScaleMoERunner.tuning_config = FP8FP4BlockScaleMoERunner.get_tuning_config(
         )

-    # The hash is used by the autotuner to get the cache key, so we hash on members
-    # that influence tactic validity here. e.g. we are tuning FC1 and FC2
-    # so the routing type does not matter
-    def __hash__(self):
-        return hash((
+    def unique_id(self):
+        return (
             self.top_k,
             self.intermediate_size,
             self.local_num_experts,
             self.act_type,
-        ))
-
-    # __eq__ and __hash__ must agree
-    def __eq__(self, other):
-        if not isinstance(other, FP8FP4BlockScaleMoERunner):
-            return False
-
-        return (self.top_k == other.top_k
-                and self.intermediate_size == other.intermediate_size
-                and self.local_num_experts == other.local_num_experts
-                and self.act_type == other.act_type)
+        )

     def get_runner(self):
         instance_key = (self.act_type, )


@@ -140,6 +140,10 @@ def test_autotuner_cache_basic():
     with autotune():
         torch.ops.autotuner_test.get_best_gemm_tactic(torch.randn(M, 64), w)

+    # This tests the logic of print_profiling_cache and print_statistics
+    AutoTuner.get().print_profiling_cache()
+    AutoTuner.get().print_statistics()
+
     m = M * 2
     while m >= 1:
         best_tactic = torch.ops.autotuner_test.get_best_gemm_tactic(