From 6313c9799c4ee1d5a81f4a4f24164308ff507f7e Mon Sep 17 00:00:00 2001 From: Yukun He <23156053+hyukn@users.noreply.github.com> Date: Wed, 17 Sep 2025 09:00:28 +0800 Subject: [PATCH] [https://nvbugs/5488582][fix] Cherry-pick 7495: Avoid unexpected Triton recompilation in DG fused_moe (#7708) Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com> --- tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 71493b2612..4c9bd9bb78 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -220,7 +220,7 @@ def _preprocess_after_permute_kernel( expert_offsets_ptr, masked_m_ptr, token_map_ptr, - TOTAL_TOKENS: tl.constexpr, + total_tokens, NUM_EXPERTS: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, ): @@ -228,7 +228,7 @@ def _preprocess_after_permute_kernel( pid_y = tl.program_id(1) if pid_y == 0: token_offsets = pid_x * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - token_mask = token_offsets < TOTAL_TOKENS + token_mask = token_offsets < total_tokens # get expert_id for each token in the block expert_ids = tl.full((BLOCK_SIZE_M, ), NUM_EXPERTS - 1, dtype=tl.int32) found_mask = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.int1) @@ -287,7 +287,7 @@ def preprocess_after_permute(expert_first_token_offset_tensor, expert_first_token_offset_tensor, masked_m, token_to_expert_map, - TOTAL_TOKENS=total_tokens, + total_tokens, NUM_EXPERTS=num_experts, BLOCK_SIZE_M=BLOCK_SIZE_M, )