[TRTLLM-6019] feat: Remove cutlass min latency code from AutoTuner. (#5394)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Author: Yukun He, 2025-06-26 13:12:03 +08:00 (committed by GitHub)
parent 942841417e
commit 9ee33605bb


@@ -22,12 +22,9 @@ def bmm_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
 class MoERunner(TunableRunner):
     # avoid overhead of creating a new runner in forward pass
     runner_dict = dict()
-    # TODO: only profile for min_latency_mode = False due to the error in the moe_kernels
     tuning_config = TuningConfig(dynamic_tensor_specs=(
         DynamicTensorSpec(0, 0, get_last_power_of_2_num_tokens_buckets(8192),
-                          lambda x: min(last_positive_power_of_2(x), 8192)),
-        DynamicTensorSpec(3, 0, (0, ), lambda x: x),
-    ))
+                          lambda x: min(last_positive_power_of_2(x), 8192)), ))
 
     def __init__(
         self,
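
The surviving DynamicTensorSpec rounds the token dimension (dim 0 of input 0) down to a power-of-two bucket capped at 8192, so the tuner only profiles a logarithmic number of shapes. A minimal sketch of that mapping, assuming the helpers behave as their names suggest (the implementations below are illustrative, not the TensorRT-LLM ones):

# Illustrative stand-ins for the bucket helpers referenced in the diff.
def last_positive_power_of_2(x: int) -> int:
    # Largest power of two that is <= x (assumes x >= 1).
    return 1 << (x.bit_length() - 1)

def get_last_power_of_2_num_tokens_buckets(max_num_tokens: int) -> tuple:
    # Candidate bucket sizes: 1, 2, 4, ..., max_num_tokens.
    buckets, b = [], 1
    while b <= max_num_tokens:
        buckets.append(b)
        b *= 2
    return tuple(buckets)

# The mapping rule from the TuningConfig: round a runtime token count down
# to its bucket, never exceeding the profiled maximum (8192 here).
map_to_bucket = lambda x: min(last_positive_power_of_2(x), 8192)

assert map_to_bucket(3000) == 2048
assert map_to_bucket(20000) == 8192

At runtime an arbitrary token count is mapped onto one of the profiled buckets, so a previously tuned tactic can be reused without re-profiling every new shape.
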
@@ -44,6 +41,7 @@ class MoERunner(TunableRunner):
         enable_alltoall: bool,
         use_deepseek_fp8_block_scale: bool,
         use_w4a8_group_scaling: bool,
+        min_latency_mode: bool,
     ):
         self.x_dtype = x_dtype
         self.weight_dtype = weight_dtype
@@ -58,7 +56,7 @@ class MoERunner(TunableRunner):
         self.enable_alltoall = enable_alltoall
         self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
         self.use_w4a8_group_scaling = use_w4a8_group_scaling
+        self.min_latency_mode = min_latency_mode
 
         instance_key = (x_dtype, weight_dtype, output_dtype,
                         use_deepseek_fp8_block_scale, use_w4a8_group_scaling)
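
Because min_latency_mode is fixed for the lifetime of a runner, it now travels as a constructor argument and instance attribute instead of being encoded as a marker tensor in the tuner's input list. A small sketch of that pattern, with illustrative class names rather than the real TunableRunner interface:

from typing import List
import torch

class ExampleRunner:
    # Illustrative stand-in: static configuration is captured once at
    # construction time ...
    def __init__(self, min_latency_mode: bool):
        self.min_latency_mode = min_latency_mode

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        # ... so the input list only carries tensors that vary per call.
        x, fc1_weights, fc2_weights = inputs
        return x  # placeholder for the real fused-MoE kernel call

runner = ExampleRunner(min_latency_mode=False)
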
@@ -74,22 +72,7 @@ class MoERunner(TunableRunner):
         inputs: List[torch.Tensor],
         profile: OptimizationProfile,
     ) -> List[int]:
-        x, _, _, min_latency_mode_tensor = inputs
-        min_latency_mode = min_latency_mode_tensor.size(0) == 1
-        m = x.shape[0]
-        # Only profile m <= 128 for min latency mode = True
-        # Profile all valid buckets for min latency mode = False
-        # TODO: min_latency_mode = True will cause the following error:
-        # Cannot profile configuration 4: Cutlass GEMM Tactic
-        # [TensorRT-LLM][ERROR] Assertion failed: Failed to initialize cutlass TMA WS grouped gemm.
-        # Should be fixed in the moe_kernels in the future.
-        invalid = (m > 128 and
-                   min_latency_mode) or (m <= 128 and min_latency_mode and
-                                         (not self.weight_dtype == torch.int64))
-        return [] if invalid else list(
-            range(self.fused_moe_runner.get_tactic_num()))
+        return range(self.fused_moe_runner.get_tactic_num())
 
     def forward(
         self,
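
With the min-latency special cases dropped, get_valid_tactics no longer filters by token count or weight dtype; every tactic index the kernel runner reports is considered valid. A sketch of the simplified behavior, using a fake runner in place of fused_moe_runner (only get_tactic_num is taken from the diff, the rest is illustrative):

from typing import List

class _FakeKernelRunner:
    def get_tactic_num(self) -> int:
        return 5  # pretend the backend exposes 5 candidate tactics

def get_valid_tactics(fused_moe_runner=_FakeKernelRunner()) -> List[int]:
    # No filtering on token count or min_latency_mode any more.
    return list(range(fused_moe_runner.get_tactic_num()))

assert get_valid_tactics() == [0, 1, 2, 3, 4]
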
@@ -98,8 +81,7 @@ class MoERunner(TunableRunner):
         tactic: int = -1,
         do_preparation: bool = False,
     ):
-        x, fc1_expert_weights, fc2_expert_weights, min_latency_mode_tensor = inputs
-        min_latency_mode = min_latency_mode_tensor.size(0) == 1
+        x, fc1_expert_weights, fc2_expert_weights = inputs
         # determine if we should use min latency mode according to the profiled seq len
         self.fused_moe_runner.run_gemm_profile(
             x,
@@ -113,7 +95,7 @@ class MoERunner(TunableRunner):
             self.cluster_size,
             self.cluster_rank,
             self.enable_alltoall,
-            min_latency_mode,
+            self.min_latency_mode,
             gemm_idx,
             tactic,
             do_preparation,
@@ -122,13 +104,11 @@ class MoERunner(TunableRunner):
     @classmethod
     @lru_cache(maxsize=None)
     def refine_tuning_config(cls, tune_max_num_tokens: int):
-        cls.tuning_config = TuningConfig(dynamic_tensor_specs=(
-            DynamicTensorSpec(
+        cls.tuning_config = TuningConfig(
+            dynamic_tensor_specs=(DynamicTensorSpec(
                 0, 0, get_last_power_of_2_num_tokens_buckets(
                     tune_max_num_tokens), lambda x: min(
                         last_positive_power_of_2(x), tune_max_num_tokens)),
-            DynamicTensorSpec(3, 0, (0, ), lambda x: x),
-        ))
+                        last_positive_power_of_2(x), tune_max_num_tokens)), ))
 
 @torch.library.custom_op("trtllm::fused_moe", mutates_args=())
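
refine_tuning_config keeps its @classmethod + @lru_cache stacking, so the class-level tuning_config is rebuilt at most once per distinct tune_max_num_tokens value. A sketch of that memoization pattern, with a stand-in config dict in place of the real TuningConfig/DynamicTensorSpec objects:

from functools import lru_cache

class Example:
    tuning_config = {"max_tokens": 8192}  # default, mirrors the 8192 cap above

    @classmethod
    @lru_cache(maxsize=None)
    def refine_tuning_config(cls, tune_max_num_tokens: int):
        # Rebuild the class-level config only once per distinct argument.
        cls.tuning_config = {"max_tokens": tune_max_num_tokens}

Example.refine_tuning_config(4096)
Example.refine_tuning_config(4096)  # cache hit: the config is not rebuilt again
assert Example.tuning_config == {"max_tokens": 4096}
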
@@ -157,9 +137,6 @@ def fused_moe(
     tuner = AutoTuner.get()
     MoERunner.refine_tuning_config(tune_max_num_tokens)
-    # TODO: set min_latency_mode always to False due to the error in the moe_kernels
-    min_latency_tensor = torch.empty(0)
-
     # allocate workspace for profiling
     moe_runner = MoERunner(
         x_dtype=input.dtype,
@@ -175,13 +152,14 @@ def fused_moe(
         enable_alltoall=enable_alltoall,
         use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         use_w4a8_group_scaling=use_w4a8_group_scaling,
+        min_latency_mode=min_latency_mode,
     )
 
     _, gemm_tactic_1 = tuner.choose_one(
         "trtllm::fused_moe::gemm1",
         [moe_runner],
         MoERunner.tuning_config,
-        [input, fc1_expert_weights, fc2_expert_weights, min_latency_tensor],
+        [input, fc1_expert_weights, fc2_expert_weights],
         gemm_idx=1,
     )
@@ -189,7 +167,7 @@ def fused_moe(
         "trtllm::fused_moe::gemm2",
         [moe_runner],
         MoERunner.tuning_config,
-        [input, fc1_expert_weights, fc2_expert_weights, min_latency_tensor],
+        [input, fc1_expert_weights, fc2_expert_weights],
         gemm_idx=2,
     )
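
Both choose_one calls now pass only the real tensor inputs; the tuner profiles each valid tactic on those inputs and returns the winner (gemm_tactic_1 / gemm_tactic_2), which is later handed to the actual kernel launch. The sketch below shows the general idea of such a profiling loop; it is an illustration of the concept only, not the TensorRT-LLM AutoTuner implementation, and choose_one here is a local helper rather than the real API:

import time
from typing import Callable, List, Tuple

def choose_one(run_with_tactic: Callable[[int], object],
               valid_tactics: List[int],
               warmup: int = 1,
               iters: int = 3) -> Tuple[float, int]:
    # Time each candidate tactic and keep the fastest one.
    best_time, best_tactic = float("inf"), -1
    for tactic in valid_tactics:
        for _ in range(warmup):
            run_with_tactic(tactic)          # warm-up, not timed
        start = time.perf_counter()
        for _ in range(iters):
            run_with_tactic(tactic)
        elapsed = (time.perf_counter() - start) / iters
        if elapsed < best_time:
            best_time, best_tactic = elapsed, tactic
    return best_time, best_tactic

# Usage: the selected tactic index is what would be passed on to the kernel.
_, best = choose_one(lambda t: sum(range(1000 * (t + 1))), valid_tactics=[0, 1, 2])
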