Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
commit 60c43a200a (parent c87e81c1d8)

[None][fix] Fix on-disk cache and revise logger/statistics for AutoTuner. (#9211)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
@@ -98,6 +98,7 @@ class TuningConfig:
         use_cold_l2_cache (bool): Whether to use cold L2 cache.
             This flag is to create circular buffer of input tensors to avoid L2 cache hits to simulate cold L2 cache.
             Notice that not all tuning processes can benefit from this feature.
+        use_cuda_graph (bool): Whether to use CUDA graph for the tuning process.
     """
     dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...] = ()
     constraint_specs: Tuple[ConstraintSpec, ...] = ()
@@ -211,8 +212,16 @@ class TunableRunner(ABC):
         """
         raise NotImplementedError

-    def __hash__(self):
-        return hash(tuple(self.__dict__.values()))
+    def unique_id(self):
+        """
+        Returns a tuple of the unique id of the runner. The unique id will be converted to a string for the cache key.
+        A common practice is to return a tuple of the runner's attributes, for example:
+            return (self.output_dtype, self.attribute_1, ...)
+
+        Returns:
+            Any: The unique id of the runner, which can be converted to a string for the cache key.
+        """
+        return tuple(self.__dict__.values())


 @contextlib.contextmanager
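For readers of this diff, a minimal sketch (not part of the commit) of how a subclass might adopt the new interface. The class name, attributes, and import path below are assumptions, and the other TunableRunner methods are omitted.

# Hypothetical example, not part of this commit: a runner keyed on the
# attributes that affect which tactics are valid.
from tensorrt_llm._torch.autotuner import TunableRunner  # import path is an assumption

class MyGemmRunner(TunableRunner):

    def __init__(self, output_dtype, to_userbuffers):
        self.output_dtype = output_dtype
        self.to_userbuffers = to_userbuffers

    def unique_id(self):
        # The returned tuple is stringified by the autotuner when building
        # the cache key for this runner.
        return (self.output_dtype, self.to_userbuffers)

    # get_valid_tactics() and forward() omitted for brevity.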
@@ -226,7 +235,6 @@ def autotune(tune_mode: bool = True, cache_path: str = None, rank: int = 0):
         # if the rank-specific file exists, load it
         file_exists = os.path.exists(cache_path_no_ext_rank)
         # if the rank-specific file exists, do not enable tuning mode
-        tune_required = tune_required and not os.path.exists(cache_path)
         if file_exists:
             logger.info(
                 f"[Autotuner] Loading cache from {cache_path_no_ext_rank}")
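As a usage note, a sketch only: the signature in the hunk header above shows autotune(tune_mode, cache_path, rank). The cache file name below is illustrative, the helper is hypothetical, and whether the cache is written back on exit is assumed from the surrounding on-disk cache logic rather than shown in this hunk.

# Illustrative sketch: tune once against a per-rank on-disk cache, so later
# runs can load the rank-specific file instead of re-tuning.
from tensorrt_llm._torch.autotuner import autotune  # import path is an assumption

with autotune(tune_mode=True, cache_path="/tmp/autotuner_cache.json", rank=0):
    warm_up_tunable_ops()  # hypothetical helper that exercises the tunable custom ops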
@@ -259,8 +267,8 @@ class AutoTunerStatistics:
         cache_misses (int): Number of cache misses requiring fallback
         cache_miss_config_collection (Dict[str, Set[OptimizationProfile]]): Collection of configs that caused cache misses
         failed_profiling_count (Dict[str, int]): Number of failed profiling attempts per operation
-        tuned_op_total_configs (Dict[str, int]): Total configurations tried per operation
-        tuned_op_successful_configs (Dict[str, int]): Successful configurations per operation
+        tuned_op_profiled_configs (Dict[str, int]): Profiled configurations per operation
+        tuned_op_time_cost (Dict[str, float]): Time cost per operation
     """
     cache_misses: int = 0
     cache_miss_config_collection: Dict[str,
@@ -268,8 +276,8 @@ class AutoTunerStatistics:
     failed_profiling_count: Dict[str, Set[Tuple[str, TunableRunner,
                                                 OptimizationProfile]]] = field(
                                                     default_factory=dict)
-    tuned_op_total_configs: Dict[str, int] = field(default_factory=dict)
-    tuned_op_successful_configs: Dict[str, int] = field(default_factory=dict)
+    tuned_op_profiled_configs: Dict[str, int] = field(default_factory=dict)
+    tuned_op_time_cost: Dict[str, float] = field(default_factory=dict)

     def __str__(self) -> str:
         """Return a string representation of collected statistics.
@@ -284,22 +292,23 @@ class AutoTunerStatistics:
             for profile in sorted(profiles, key=str):
                 stats_str += f" - Config: {profile}\n"

-        if self.tuned_op_total_configs:
+        if self.tuned_op_profiled_configs:
             stats_str += "Tuned operations:\n"
-            for op in sorted(self.tuned_op_total_configs.keys()):
-                total = self.tuned_op_total_configs[op]
-                successful = self.tuned_op_successful_configs.get(op, 0)
-                failed = len(self.failed_profiling_count.get(op, set()))
-                success_rate = (successful / total * 100) if total > 0 else 0
+            for op in sorted(self.tuned_op_profiled_configs.keys()):
+                successful = self.tuned_op_profiled_configs[op]
+                failed = len(self.failed_profiling_count[op])
                 stats_str += f"  {op}:\n"
-                stats_str += f" - Total configs tried: {total}\n"
                 stats_str += f" - Successful configs: {successful}\n"
                 stats_str += f" - Failed profiling count: {failed}\n"
                 if failed > 0:
                     stats_str += f" - Failed profiling combinations:\n"
                     for failed_key in self.failed_profiling_count[op]:
                         stats_str += f" - {failed_key}\n"
-                stats_str += f" - Success rate: {success_rate:.1f}%\n"
+
+        if self.tuned_op_time_cost:
+            stats_str += "Tuned operations time cost:\n"
+            for op in sorted(self.tuned_op_time_cost.keys()):
+                stats_str += f"  {op}: {self.tuned_op_time_cost[op] * 1000:.4f} milliseconds\n"

         return stats_str

@@ -374,7 +383,7 @@ class AutoTunerProfilingCache:
         return (
             custom_op,
             runner.__class__.__name__,
-            hash(runner),
+            str(runner.unique_id()),
             AutoTuner.get()._find_nearest_profile(
                 input_shapes,
                 tuning_config.dynamic_tensor_specs,
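To make the change concrete: the third component of the cache key is now the string form of the runner's unique_id() rather than hash(runner). A schematic example with illustrative stand-in values only; the real key also embeds the nearest optimization profile computed above.

import torch

# Schematic only: placeholder values standing in for the real op and profile objects.
nearest_profile = ((128, 7168), (7168, 2048))       # placeholder for _find_nearest_profile(...)
cache_key = (
    "my_custom_op",                                  # custom_op (illustrative name)
    "FP4GemmRunner",                                 # runner.__class__.__name__
    str((False, torch.bfloat16)),                    # str(runner.unique_id())
    nearest_profile,
)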
@@ -546,6 +555,11 @@ class AutoTuner:
         # Last captured choose_one() contexts
         self._last_capture: Optional['AutoTuner.TacticsCapture'] = None

+        # Increase log level for AutoTuner associated logger
+        self._log_level_to_info = os.getenv(
+            "TLLM_AUTOTUNER_LOG_LEVEL_DEBUG_TO_INFO", '0') == '1'
+        self._debug_logger = logger.info if self._log_level_to_info else logger.debug
+
     @classmethod
     def get(cls):
         if cls._instance is None:
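A quick way to surface these AutoTuner messages at INFO level without raising the global log level. Only the environment variable name comes from the diff above; note it has to be set before the AutoTuner instance is created, since the flag is read in the constructor.

import os

# Promote AutoTuner debug logging to INFO for this process.
os.environ["TLLM_AUTOTUNER_LOG_LEVEL_DEBUG_TO_INFO"] = "1"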
@@ -726,11 +740,14 @@ class AutoTuner:
         assert all([isinstance(r, TunableRunner) for r in runners]), \
             "All Given runners must be subclass of TunableRunner"

+        tuning_start_time = time.perf_counter()
         profiles = self._optimization_profiles(tuning_config, inputs)

-        # Record the total configs to try
-        self.stats.tuned_op_total_configs[custom_op] = len(profiles)
+        # Initialize the statistics for the custom_op
+        if custom_op not in self.stats.tuned_op_profiled_configs:
+            self.stats.tuned_op_profiled_configs[custom_op] = 0
+        if custom_op not in self.stats.failed_profiling_count:
+            self.stats.failed_profiling_count[custom_op] = set()
         new_tuning_failure_occured = False

         for p in profiles:
@@ -746,16 +763,15 @@ class AutoTuner:
                 cache_key = self.profiling_cache.get_cache_key(
                     custom_op, runners[best_runner_id], p.get_opt_shapes(),
                     tuning_config)

+                self._debug_logger(
+                    f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
+                )
                 # inspect call stack
                 self.profiling_cache[cache_key] = (best_runner_id,
                                                    best_tactic, min_time)

-                self.stats.tuned_op_successful_configs[
-                    custom_op] = self.stats.tuned_op_successful_configs.get(
-                        custom_op, 0) + 1
-                logger.debug(
-                    f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
-                )
+                self.stats.tuned_op_profiled_configs[custom_op] += 1
             else:
                 logger.warning_once(
                     f"[Autotuner] No valid runner/tactic was found for custom_op={custom_op}, input_shapes={input_shapes}. "
@@ -782,6 +798,10 @@ class AutoTuner:
         _, runner_id, tactic, _ = self.profiling_cache.search_cache(
             custom_op, runners, input_shapes, tuning_config)

+        tuning_end_time = time.perf_counter()
+        self.stats.tuned_op_time_cost[
+            custom_op] = self.stats.tuned_op_time_cost.get(
+                custom_op, 0) + tuning_end_time - tuning_start_time
         return (runners[runner_id], tactic)

     def _profile_runners(
@@ -832,14 +852,13 @@ class AutoTuner:
                         f"[Autotuner] Failed when profiling runner={runner}, tactic={tac}, shapes={shapes}. Set TLLM_LOG_LEVEL=DEBUG for more details.",
                         key=(custom_op, "warning_autotuning_profile_failure"),
                     )
-                    logger.debug_once(
-                        f"[Autotuner] Exception captured: {e}",
-                        key=(custom_op, "debug_autotuning_exception"),
-                    )
+                    (logger.info_once
+                     if self._log_level_to_info else logger.debug_once)(
+                         f"[Autotuner] Exception captured: {e}",
+                         key=(custom_op, "debug_autotuning_exception"),
+                     )

                     # Record the failed profiling combinations
-                    if custom_op not in self.stats.failed_profiling_count:
-                        self.stats.failed_profiling_count[custom_op] = set()
                     self.stats.failed_profiling_count[custom_op].add(
                         self.profiling_cache.get_cache_key(
                             custom_op, runner, profile.get_opt_shapes(),
@@ -957,7 +976,7 @@ class AutoTuner:
             avg_time = pure_profile(stream, self.repeat)

             shapes = self._get_input_sizes(inputs)
-            logger.debug(
+            self._debug_logger(
                 f"[Autotuner] Profiled runner={runner}, tactic={tactic}, shapes={shapes}: {avg_time:.6f}ms."
             )

@@ -1043,7 +1062,7 @@ class AutoTuner:
                     p.shapes[spec.input_idx][spec.dim_idx] = DynamicDim(
                         min_value, opt_value, max_value)
             generated_profiles.append(p)
-            logger.debug(f"[Autotuner] Generated profile: {p}")
+            self._debug_logger(f"[Autotuner] Generated profile: {p}")
         return generated_profiles

     @classmethod
@@ -1159,7 +1178,7 @@ class AutoTuner:
             input.element_size() if isinstance(input, torch.Tensor) else 0
             for input in inputs)
         if one_buffer_bytes <= 0:
-            logger.debug(
+            self._debug_logger(
                 "[Autotuner] No tensor inputs or zero-sized tensors; falling back to single-batch profiling."
             )
             return [inputs]
@@ -1174,7 +1193,7 @@ class AutoTuner:
             list(t.clone() if isinstance(t, torch.Tensor) else t
                  for t in inputs))

-        logger.debug(
+        self._debug_logger(
             f"[Autotuner] use_cold_l2_cache={tuning_config.use_cold_l2_cache}, use {num_buffers} different tensors for profiling"
         )
         return inputs_list
@@ -1188,16 +1207,23 @@ class AutoTuner:
         self.stats = AutoTunerStatistics()

     def print_profiling_cache(self):
-        logger.debug(f"[Autotuner] The profiling_cache entries:")
-        logger.debug(
+        self._debug_logger(f"[Autotuner] The profiling_cache entries:")
+        self._debug_logger(
             f"[Autotuner] Cache contents: (custom_op, runner, hash(attributes), shape_profiles) -> (runner_id, tactic, shape_profile(ignored))"
         )
         for key, value in self.profiling_cache.cache.items():
             runner_id, tactic, min_time = value
-            logger.debug(
+            self._debug_logger(
                 f"[Autotuner] {key}: (runner_id={runner_id}, tactic={tactic}, min_time={min_time})"
             )
+        self.print_statistics()
+
+    def print_statistics(self):
+        self._debug_logger(f"[Autotuner] The statistics:")
+        for line in self.stats.__str__().split("\n"):
+            self._debug_logger(line)

     @contextlib.contextmanager
     def capture(self):
         """Context manager for capturing execution contexts for testing.
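The new print_statistics() hook can also be driven directly from the singleton, as the test at the bottom of this diff does; a minimal sketch, with the import path assumed:

from tensorrt_llm._torch.autotuner import AutoTuner  # import path is an assumption

tuner = AutoTuner.get()
tuner.print_profiling_cache()   # dumps cache entries, then calls print_statistics()
tuner.print_statistics()        # or emit just the statistics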
@@ -1271,7 +1297,7 @@ class AutoTuner:
             runner_idx = runners.index(runner)
             runner_tactic_list.append((runner_idx, tactic))

-        logger.debug(
+        self._debug_logger(
             f"[Autotuner][replay]: Testing configuration: {runner_tactic_list}")

         # Replay the contexts with given (runner, tactic) pairs
@@ -55,13 +55,8 @@ if IS_CUTLASS_DSL_AVAILABLE:
         )

         # rewrite the hash function because the value of self.alpha doesn't affect the tactic.
-        def __hash__(self):
-            return hash((self.output_dtype, ))
-
-        def __eq__(self, other):
-            if not isinstance(other, self.__class__):
-                return False
-            return self.output_dtype == other.output_dtype
+        def unique_id(self):
+            return (self.output_dtype, )

         def get_valid_tactics(
             self,
@@ -93,6 +93,29 @@ class MoERunner(TunableRunner):
                           profile: OptimizationProfile, **kwargs) -> List[int]:
         return range(self.fused_moe_runner.get_tactic_num(kwargs["gemm_idx"]))

+    def unique_id(self):
+        return (
+            self.x_dtype,
+            self.weight_dtype,
+            self.output_dtype,
+            self.top_k,
+            self.tp_size,
+            self.tp_rank,
+            self.ep_size,
+            self.ep_rank,
+            self.cluster_size,
+            self.cluster_rank,
+            self.enable_alltoall,
+            self.use_deepseek_fp8_block_scale,
+            self.use_w4_group_scaling,
+            self.use_int8_woq_per_channel,
+            self.use_mxfp8_act_scaling,
+            self.min_latency_mode,
+            self.use_fused_finalize,
+            self.activation_type,
+            self.unpadded_hidden_size,
+        )
+
     def forward(
         self,
         inputs: List[torch.Tensor],
@@ -316,6 +339,12 @@ class FP8RowwiseGemmRunner(TunableRunner):
         self.fp8_rowwise_gemm_runner = FP8RowwiseGemmRunner.runner_dict[
             instance_key]

+    def unique_id(self):
+        return (
+            self.to_userbuffers,
+            self.output_dtype,
+        )
+
     def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile, **kwargs) -> List[int]:
         return list(range(self.fp8_rowwise_gemm_runner.get_num_configs()))
@@ -398,6 +427,12 @@ class FP4GemmRunner(TunableRunner):
             output_dtype, int(fp4_gemm_type))
         self.fp4_gemm_runner = FP4GemmRunner.runner_dict[instance_key]

+    def unique_id(self):
+        return (
+            self.to_userbuffers,
+            self.output_dtype,
+        )
+
     def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile, **kwargs) -> List[int]:
         return list(range(self.fp4_gemm_runner.get_num_configs()))
@@ -447,6 +482,12 @@ class CublasLtFP4GemmRunner(TunableRunner):

         self.cublaslt_runner = CublasLtFP4GemmRunner.runner_dict[instance_key]

+    def unique_id(self):
+        return hash((
+            self.to_userbuffers,
+            self.output_dtype,
+        ))
+
     def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile, **kwargs) -> List[int]:
         """Get all valid tactics (algorithms) from cuBLASLt heuristic."""
@@ -592,6 +633,15 @@ class FP8BatchedGemmRunner(TunableRunner):

         self.kernel_runner = FP8BatchedGemmRunner.runner_dict[instance_key]

+    def unique_id(self):
+        return (
+            self.output_dtype,
+            self.use_deep_seek_fp8,
+            self.low_latency_kernel,
+            self.tile_size,
+            self.epilogue_tile_m,
+        )
+
     def forward(
         self,
         inputs: List[torch.Tensor],
@@ -827,6 +877,12 @@ class WeightOnlyQuantGemmRunner(TunableRunner):
         self.weight_only_quant_gemm_runner = WeightOnlyQuantGemmRunner.runner_dict[
             instance_key]

+    def unique_id(self):
+        return (
+            self.output_dtype,
+            self.to_userbuffers,
+        )
+
     def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile, **kwargs) -> List[int]:
         return list(range(self.weight_only_quant_gemm_runner.get_num_configs()))
@@ -894,6 +950,9 @@ class FinegrainedMixedDtypeGemm(TunableRunner):

     def __init__(self, activation_dtype: torch.dtype, output_dtype: torch.dtype,
                  quant_mode: int):
+        self.activation_dtype = activation_dtype
+        self.output_dtype = output_dtype
+        self.quant_mode = quant_mode
         instance_key = (activation_dtype, output_dtype, quant_mode)
         if instance_key not in FinegrainedMixedDtypeGemm._runner_dict:
             FinegrainedMixedDtypeGemm._runner_dict[
@@ -902,6 +961,13 @@ class FinegrainedMixedDtypeGemm(TunableRunner):
         self._finegrained_mixed_dtype_gemm_runner = FinegrainedMixedDtypeGemm._runner_dict[
             instance_key]

+    def unique_id(self):
+        return (
+            self.activation_dtype,
+            self.output_dtype,
+            self.quant_mode,
+        )
+
     def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile, **kwargs) -> List[int]:
         return list(
@@ -1012,6 +1078,12 @@ class fp8SwapABGemmRunner(TunableRunner):
         self.output_dtype = output_dtype
         self.disable_ue8m0_cast = disable_ue8m0_cast

+    def unique_id(self):
+        return (
+            self.output_dtype,
+            self.disable_ue8m0_cast,
+        )
+
     def get_valid_tactics(
         self,
         inputs: List[torch.Tensor],
@@ -191,24 +191,10 @@ class FP4BlockScaleMoERunner(TunableRunner):
         FP4BlockScaleMoERunner.tuning_config = FP4BlockScaleMoERunner.get_tuning_config(
         )

-    # The hash is used by the autotuner to get the cache key, so we hash on members
-    # that influence tactic validity here. e.g. we are tuning FC1 and FC2
-    # so the routing type does not matter
-    def __hash__(self):
-        return hash((
-            self.top_k,
-            self.intermediate_size,
-            self.local_num_experts,
-        ))
-
-    # __eq__ and __hash__ must agree
-    def __eq__(self, other):
-        if not isinstance(other, FP4BlockScaleMoERunner):
-            return False
-
-        return (self.top_k == other.top_k
-                and self.intermediate_size == other.intermediate_size
-                and self.local_num_experts == other.local_num_experts)
+    # The unique_id is used by the autotuner to get the cache key, so we hash on members
+    # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing type does not matter
+    def unique_id(self):
+        return (self.top_k, self.intermediate_size, self.local_num_experts)

     def get_runner(self):
         instance_key = ()
@@ -558,24 +544,11 @@ class FP8BlockScaleMoERunner(TunableRunner):
         FP8BlockScaleMoERunner.tuning_config = FP8BlockScaleMoERunner.get_tuning_config(
         )

-    # The hash is used by the autotuner to get the cache key, so we hash on members
+    # The unique_id is used by the autotuner to get the cache key, so we hash on members
     # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing
     # type does not matter
-    def __hash__(self):
-        return hash((
-            self.top_k,
-            self.intermediate_size,
-            self.local_num_experts,
-        ))
-
-    # __eq__ and __hash__ must agree
-    def __eq__(self, other):
-        if not isinstance(other, FP8BlockScaleMoERunner):
-            return False
-
-        return (self.top_k == other.top_k
-                and self.intermediate_size == other.intermediate_size
-                and self.local_num_experts == other.local_num_experts)
+    def unique_id(self):
+        return (self.top_k, self.intermediate_size, self.local_num_experts)

     def get_runner(self):
         instance_key = ()
@@ -845,30 +818,18 @@ class MxE4m3MxE2m1BlockScaleMoERunner(TunableRunner):
         MxE4m3MxE2m1BlockScaleMoERunner.tuning_config = MxE4m3MxE2m1BlockScaleMoERunner.get_tuning_config(
         )

-    # The hash is used by the autotuner to get the cache key, so we hash on members
+    # The unique_id is used by the autotuner to get the cache key, so we hash on members
     # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing
     # type does not matter
-    def __hash__(self):
-        return hash((
+    def unique_id(self):
+        return (
             self.top_k,
             self.intermediate_size,
             self.valid_hidden_size,
             self.valid_intermediate_size,
             self.local_num_experts,
             self.act_type,
-        ))
-
-    # __eq__ and __hash__ must agree
-    def __eq__(self, other):
-        if not isinstance(other, MxE4m3MxE2m1BlockScaleMoERunner):
-            return False
-
-        return (self.top_k == other.top_k
-                and self.intermediate_size == other.intermediate_size
-                and self.valid_hidden_size == other.valid_hidden_size and
-                self.valid_intermediate_size == other.valid_intermediate_size
-                and self.local_num_experts == other.local_num_experts
-                and self.act_type == other.act_type)
+        )

     def get_runner(self):
         instance_key = (self.act_type, True)
@@ -1145,30 +1106,18 @@ class E4m3MxE2m1BlockScaleMoERunner(TunableRunner):
         E4m3MxE2m1BlockScaleMoERunner.tuning_config = E4m3MxE2m1BlockScaleMoERunner.get_tuning_config(
         )

-    # The hash is used by the autotuner to get the cache key, so we hash on members
+    # The unique_id is used by the autotuner to get the cache key, so we hash on members
     # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing
     # type does not matter
-    def __hash__(self):
-        return hash((
+    def unique_id(self):
+        return (
             self.top_k,
             self.intermediate_size,
             self.valid_hidden_size,
             self.valid_intermediate_size,
             self.local_num_experts,
             self.act_type,
-        ))
-
-    # __eq__ and __hash__ must agree
-    def __eq__(self, other):
-        if not isinstance(other, E4m3MxE2m1BlockScaleMoERunner):
-            return False
-
-        return (self.top_k == other.top_k
-                and self.intermediate_size == other.intermediate_size
-                and self.valid_hidden_size == other.valid_hidden_size and
-                self.valid_intermediate_size == other.valid_intermediate_size
-                and self.local_num_experts == other.local_num_experts
-                and self.act_type == other.act_type)
+        )

     def get_runner(self):
         instance_key = (self.act_type, False)
@@ -1425,10 +1374,10 @@ class Bf16MxE2m1BlockScaleMoERunner(TunableRunner):
         Bf16MxE2m1BlockScaleMoERunner.tuning_config = Bf16MxE2m1BlockScaleMoERunner.get_tuning_config(
         )

-    # The hash is used by the autotuner to get the cache key, so we hash on members
+    # The unique_id is used by the autotuner to get the cache key, so we hash on members
     # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing
     # type does not matter
-    def __hash__(self):
+    def unique_id(self):
         return hash((
             self.top_k,
             self.intermediate_size,
@@ -1438,18 +1387,6 @@ class Bf16MxE2m1BlockScaleMoERunner(TunableRunner):
             self.act_type,
         ))

-    # __eq__ and __hash__ must agree
-    def __eq__(self, other):
-        if not isinstance(other, Bf16MxE2m1BlockScaleMoERunner):
-            return False
-
-        return (self.top_k == other.top_k
-                and self.intermediate_size == other.intermediate_size
-                and self.valid_hidden_size == other.valid_hidden_size and
-                self.valid_intermediate_size == other.valid_intermediate_size
-                and self.local_num_experts == other.local_num_experts
-                and self.act_type == other.act_type)
-
     def get_runner(self):
         instance_key = (self.act_type, )
         if instance_key not in Bf16MxE2m1BlockScaleMoERunner.runner_dict:
@@ -1695,26 +1632,13 @@ class FP8FP4BlockScaleMoERunner(TunableRunner):
         FP8FP4BlockScaleMoERunner.tuning_config = FP8FP4BlockScaleMoERunner.get_tuning_config(
         )

-    # The hash is used by the autotuner to get the cache key, so we hash on members
-    # that influence tactic validity here. e.g. we are tuning FC1 and FC2
-    # so the routing type does not matter
-    def __hash__(self):
-        return hash((
+    def unique_id(self):
+        return (
             self.top_k,
             self.intermediate_size,
             self.local_num_experts,
             self.act_type,
-        ))
-
-    # __eq__ and __hash__ must agree
-    def __eq__(self, other):
-        if not isinstance(other, FP8FP4BlockScaleMoERunner):
-            return False
-
-        return (self.top_k == other.top_k
-                and self.intermediate_size == other.intermediate_size
-                and self.local_num_experts == other.local_num_experts
-                and self.act_type == other.act_type)
+        )

     def get_runner(self):
         instance_key = (self.act_type, )
@@ -140,6 +140,10 @@ def test_autotuner_cache_basic():
     with autotune():
         torch.ops.autotuner_test.get_best_gemm_tactic(torch.randn(M, 64), w)

+    # This tests the logic of print_profiling_cache and print_statistics
+    AutoTuner.get().print_profiling_cache()
+    AutoTuner.get().print_statistics()
+
     m = M * 2
     while m >= 1:
         best_tactic = torch.ops.autotuner_test.get_best_gemm_tactic(