mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[TRTLLM-6019] feat: Remove cutlass min latency code from AutoTuner. (#5394)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
This commit is contained in:
parent 942841417e
commit 9ee33605bb
@@ -22,12 +22,9 @@ def bmm_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
 class MoERunner(TunableRunner):
     # avoid overhead of creating a new runner in forward pass
     runner_dict = dict()
-    # TODO: only profile for min_latency_mode = False due to the error in the moe_kernels
     tuning_config = TuningConfig(dynamic_tensor_specs=(
         DynamicTensorSpec(0, 0, get_last_power_of_2_num_tokens_buckets(8192),
-                          lambda x: min(last_positive_power_of_2(x), 8192)),
-        DynamicTensorSpec(3, 0, (0, ), lambda x: x),
-    ))
+                          lambda x: min(last_positive_power_of_2(x), 8192)), ))

     def __init__(
         self,
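The spec kept above buckets dimension 0 of input 0 (the token count) into powers of two capped at 8192; the dropped DynamicTensorSpec(3, 0, (0, ), lambda x: x) covered the dummy min-latency tensor at input index 3, which this commit removes entirely. Below is a small, self-contained sketch of that bucketing, re-implemented for illustration only; the real last_positive_power_of_2 and get_last_power_of_2_num_tokens_buckets live in the library and may differ in detail.

# Illustrative re-implementation of the shape bucketing kept by the remaining
# DynamicTensorSpec; not the library's code.
def last_positive_power_of_2(x: int) -> int:
    # Largest power of two that is <= x (assumes x >= 1).
    p = 1
    while p * 2 <= x:
        p *= 2
    return p

def map_num_tokens_to_bucket(num_tokens: int, cap: int = 8192) -> int:
    # Round the token count down to a power of two and clamp it, so the tuner
    # only profiles a bounded, reusable set of shapes.
    return min(last_positive_power_of_2(num_tokens), cap)

assert map_num_tokens_to_bucket(3000) == 2048
assert map_num_tokens_to_bucket(20000) == 8192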
@@ -44,6 +41,7 @@ class MoERunner(TunableRunner):
         enable_alltoall: bool,
         use_deepseek_fp8_block_scale: bool,
         use_w4a8_group_scaling: bool,
+        min_latency_mode: bool,
     ):
         self.x_dtype = x_dtype
         self.weight_dtype = weight_dtype
@@ -58,7 +56,7 @@ class MoERunner(TunableRunner):
         self.enable_alltoall = enable_alltoall
         self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
         self.use_w4a8_group_scaling = use_w4a8_group_scaling
-
+        self.min_latency_mode = min_latency_mode
         instance_key = (x_dtype, weight_dtype, output_dtype,
                         use_deepseek_fp8_block_scale, use_w4a8_group_scaling)

@@ -74,22 +72,7 @@ class MoERunner(TunableRunner):
         inputs: List[torch.Tensor],
         profile: OptimizationProfile,
     ) -> List[int]:
-        x, _, _, min_latency_mode_tensor = inputs
-        min_latency_mode = min_latency_mode_tensor.size(0) == 1
-        m = x.shape[0]
-
-        # Only profile m <= 128 for min latency mode = True
-        # Profile all valid buckets for min latency mode = False
-        # TODO: min_latency_mode = True will cause the following error:
-        # Cannot profile configuration 4: Cutlass GEMM Tactic
-        # [TensorRT-LLM][ERROR] Assertion failed: Failed to initialize cutlass TMA WS grouped gemm.
-        # Should be fixed in the moe_kernels in the future.
-        invalid = (m > 128 and
-                   min_latency_mode) or (m <= 128 and min_latency_mode and
-                                         (not self.weight_dtype == torch.int64))
-
-        return [] if invalid else list(
-            range(self.fused_moe_runner.get_tactic_num()))
+        return range(self.fused_moe_runner.get_tactic_num())

     def forward(
         self,
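With the min-latency special cases gone, get_valid_tactics now simply exposes every tactic index to the tuner. For intuition, here is a generic "time each candidate, keep the fastest" loop of the kind an auto-tuner runs over such a list; this is a hedged sketch, not TensorRT-LLM's AutoTuner implementation, which layers shape bucketing (via TuningConfig) and caching on top.

import time
from typing import Callable, Iterable

# Generic tactic-selection loop, for illustration only.
def pick_fastest_tactic(run: Callable[[int], None],
                        valid_tactics: Iterable[int],
                        warmup: int = 1,
                        iters: int = 3) -> int:
    best_tactic, best_time = -1, float("inf")
    for tactic in valid_tactics:
        for _ in range(warmup):
            run(tactic)  # warm up before timing
        start = time.perf_counter()
        for _ in range(iters):
            run(tactic)
        elapsed = (time.perf_counter() - start) / iters
        if elapsed < best_time:
            best_tactic, best_time = tactic, elapsed
    return best_tactic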
@@ -98,8 +81,7 @@ class MoERunner(TunableRunner):
         tactic: int = -1,
         do_preparation: bool = False,
     ):
-        x, fc1_expert_weights, fc2_expert_weights, min_latency_mode_tensor = inputs
-        min_latency_mode = min_latency_mode_tensor.size(0) == 1
+        x, fc1_expert_weights, fc2_expert_weights = inputs
         # determine if we should use min latency mode according to the profiled seq len
         self.fused_moe_runner.run_gemm_profile(
             x,
@@ -113,7 +95,7 @@ class MoERunner(TunableRunner):
             self.cluster_size,
             self.cluster_rank,
             self.enable_alltoall,
-            min_latency_mode,
+            self.min_latency_mode,
             gemm_idx,
             tactic,
             do_preparation,
@@ -122,13 +104,11 @@ class MoERunner(TunableRunner):
     @classmethod
     @lru_cache(maxsize=None)
     def refine_tuning_config(cls, tune_max_num_tokens: int):
-        cls.tuning_config = TuningConfig(dynamic_tensor_specs=(
-            DynamicTensorSpec(
+        cls.tuning_config = TuningConfig(
+            dynamic_tensor_specs=(DynamicTensorSpec(
                 0, 0, get_last_power_of_2_num_tokens_buckets(
                     tune_max_num_tokens), lambda x: min(
-                        last_positive_power_of_2(x), tune_max_num_tokens)),
-            DynamicTensorSpec(3, 0, (0, ), lambda x: x),
-        ))
+                        last_positive_power_of_2(x), tune_max_num_tokens)), ))


 @torch.library.custom_op("trtllm::fused_moe", mutates_args=())
@@ -157,9 +137,6 @@ def fused_moe(
     tuner = AutoTuner.get()
     MoERunner.refine_tuning_config(tune_max_num_tokens)

-    # TODO: set min_latency_mode always to False due to the error in the moe_kernels
-    min_latency_tensor = torch.empty(0)
-
     # allocate workspace for profiling
     moe_runner = MoERunner(
         x_dtype=input.dtype,
@@ -175,13 +152,14 @@ def fused_moe(
         enable_alltoall=enable_alltoall,
         use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
         use_w4a8_group_scaling=use_w4a8_group_scaling,
+        min_latency_mode=min_latency_mode,
     )

     _, gemm_tactic_1 = tuner.choose_one(
         "trtllm::fused_moe::gemm1",
         [moe_runner],
         MoERunner.tuning_config,
-        [input, fc1_expert_weights, fc2_expert_weights, min_latency_tensor],
+        [input, fc1_expert_weights, fc2_expert_weights],
         gemm_idx=1,
     )

@@ -189,7 +167,7 @@ def fused_moe(
         "trtllm::fused_moe::gemm2",
         [moe_runner],
         MoERunner.tuning_config,
-        [input, fc1_expert_weights, fc2_expert_weights, min_latency_tensor],
+        [input, fc1_expert_weights, fc2_expert_weights],
         gemm_idx=2,
     )

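Taken together, the hunks above change the caller-side contract of the tuning path: min_latency_mode now travels on the MoERunner itself, and the tuner profiles only the real GEMM inputs. A condensed, illustrative view of the post-commit flow inside fused_moe follows (not runnable on its own: the unchanged MoERunner keyword arguments are elided, and the gemm2 call changes in the same way as gemm1).

# Condensed post-commit flow; assumes the surrounding fused_moe scope
# from the diff above.
tuner = AutoTuner.get()
MoERunner.refine_tuning_config(tune_max_num_tokens)

moe_runner = MoERunner(
    x_dtype=input.dtype,
    # ... unchanged dtype / parallelism keyword arguments elided ...
    min_latency_mode=min_latency_mode,  # now carried by the runner itself
)

_, gemm_tactic_1 = tuner.choose_one(
    "trtllm::fused_moe::gemm1",
    [moe_runner],
    MoERunner.tuning_config,
    [input, fc1_expert_weights, fc2_expert_weights],  # no dummy min_latency_tensor
    gemm_idx=1,
)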