chore: Improve the AutoTuner log information. (#6368)

* Change the fallback alert from DEBUG to WARNING level and only do it once.
* Add debug information for profiling cache right after the warmup phase.
* Change the level of exception message during tactic profiling from ERROR to WARNING level. All exception details are pushed to the DEBUG level.
* Other trivial refinements and cleanups.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
This commit is contained in:
Yukun He 2025-08-01 09:19:52 +08:00 committed by GitHub
parent 2eca0d5925
commit 00059de380
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 21 deletions

View File

@ -197,13 +197,13 @@ def autotune(tune_mode: bool = True):
AutoTuner.get().is_tuning_mode = tune_mode
autotune_enabled = tune_mode and not old_mode
if autotune_enabled:
logger.info("[Autotuner]: Autotuning process starts ...")
logger.info("[Autotuner] Autotuning process starts ...")
try:
yield
finally:
AutoTuner.get().is_tuning_mode = old_mode
if autotune_enabled:
logger.info("[Autotuner]: Autotuning process ends")
logger.info("[Autotuner] Autotuning process ends")
@dataclass
@ -350,16 +350,11 @@ class AutoTuner:
runner = runners[runner_id]
# TODO: check the stored runner and tactic can implement this shape here
# Should not directly try (runner, tactic) here, or it will hurt a lot of inference perf.
# Record the cache miss config.
# Expect no cache miss in inference. Thus, any cache miss should be recorded.
if not is_cache_hit:
logger.debug(
f"[AutoTunner]: Using fallback tactic for {custom_op} with input shapes {input_shapes}"
)
logger.debug(
f"[AutoTunner]: Generated key{AutoTuner._get_cache_key(custom_op, runners[0], input_shapes, tuning_config)}"
)
if not is_cache_hit and len(self.profiling_cache) > 0:
# Only log once for each custom op and only when cache is not empty
logger.warning_once(
f"[AutoTunner] Using the fallback tactic, due to cache miss on input shapes={input_shapes}",
key=(custom_op))
return runner, tactic
assert len(runners) > 0, "At least one runner is required"
@ -370,6 +365,8 @@ class AutoTuner:
# Record the total configs to try
self.stats.tuned_op_total_configs[custom_op] = len(profiles)
new_tuning_failure_occured = False
for p in profiles:
tensors = self._prepare_input_tensors(p, inputs)
is_cache_hit, runner_id, tactic, _ = self.search_cache(
@ -396,11 +393,13 @@ class AutoTuner:
except Exception as e:
shapes = self._get_input_sizes(tensors)
logger.error(
f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}"
logger.warning(
f"[Autotuner] Failed when profiling runner={r}, tactic={tac}, shapes={shapes}. Set TLLM_LOG_LEVEL=DEBUG for more details."
)
logger.debug(f"[Autotuner] Exception captured: {e}")
# Record the failed profiling combinations
new_tuning_failure_occured = True
if custom_op not in self.stats.failed_profiling_count:
self.stats.failed_profiling_count[
custom_op] = set()
@ -426,8 +425,25 @@ class AutoTuner:
custom_op] = self.stats.tuned_op_successful_configs.get(
custom_op, 0) + 1
logger.debug(
f"[Autotuner]: profiling chosen runner: {runners[runner_id]} {tactic} for {cache_key}"
f"[Autotuner] Profiling runner={runners[runner_id]}, tactic={tactic} for cache_key={cache_key}."
)
else:
logger.warning(
f"[Autotuner] No valid runner/tactic was found for custom_op={custom_op}, input_shapes={input_shapes}. "
f"At least one valid (runner, tactic) pair is required. "
f"If get_valid_tactics is intended to return empty list, please ensure that this profile is not valid for the custom_op "
f"and should not occurs during the inference stage, or fallback tactic is implemented. Otherwise, the the tuning process will crash."
)
# If failed profiling tactics occurs, log the error.
if new_tuning_failure_occured:
logger.warning(
f"[Autotuner] New tuning error occurs:"
f"Total failed profiling tactics occurs: {len(self.stats.failed_profiling_count[custom_op])} for custom_op={custom_op}. "
f"This will not block the tuning process. "
f"Please set TLLM_LOG_LEVEL=WARNING to find out when the tactic profiling fails. "
f"Set TLLM_LOG_LEVEL=DEBUG to get more details of the failures."
)
# Get the best runner and tactic from cache
# If no valid tactic is found, the fallback runner and tactic will be used
@ -487,7 +503,7 @@ class AutoTuner:
shapes = self._get_input_sizes(inputs)
logger.debug(
f"[Autotuner]: profiling {runner} {tactic}, shapes={shapes}, avg_time {avg_time}"
f"[Autotuner] Profiled runner={runner}, tactic={tactic}, shapes={shapes}: {avg_time}ms."
)
return avg_time
@ -557,7 +573,7 @@ class AutoTuner:
p.shapes[spec.input_idx][spec.dim_idx] = DynamicDim(
min_value, opt_value, max_value)
generated_profiles.append(p)
logger.debug(f"[Autotuner]: generated profile: {p}")
logger.debug(f"[Autotuner] Generated profile: {p}")
return generated_profiles
@classmethod
@ -649,3 +665,12 @@ class AutoTuner:
def reset_statistics(self) -> None:
"""Reset all statistics counters."""
self.stats = AutoTunerStatistics()
def print_profiling_cache(self):
logger.debug(f"[Autotuner] The profiling_cache entries:")
logger.debug(
f"[Autotuner] Cache contents: (custom_op, runner, hash(attributes), shape_profiles) -> (runner_id, tactic, shape_profile(ignored))"
)
for key, value in self.profiling_cache.items():
runner_id, tactic, _ = value
logger.debug(f"[Autotuner] {key}: ({runner_id}, {tactic})")

View File

@ -695,15 +695,16 @@ class PyTorchModelEngine(ModelEngine):
# No KV cache space!
pass
else:
logger.info(
f"Run autotuning warmup for batch size={1}")
self.forward(batch,
new_tensors_device=None,
resource_manager=resource_manager)
torch.cuda.synchronize()
logger.info(f"Autotuner Cache size after warmup " +
str(len(AutoTuner.get().profiling_cache)))
logger.info(
f"[Autotuner] Cache size after warmup is {len(AutoTuner.get().profiling_cache)}"
)
AutoTuner.get().print_profiling_cache()
if not (self._run_cuda_graphs
or self._torch_compile_piecewise_cuda_graph):