[https://nvbugs/5488582][fix] Cherry-pick 7495: Avoid unexpected Triton recompilation in DG fused_moe (#7708)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
This commit is contained in:
Yukun He 2025-09-17 09:00:28 +08:00 committed by GitHub
parent 8bdbb48264
commit 6313c9799c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -220,7 +220,7 @@ def _preprocess_after_permute_kernel(
 expert_offsets_ptr,
 masked_m_ptr,
 token_map_ptr,
-TOTAL_TOKENS: tl.constexpr,
+total_tokens,
 NUM_EXPERTS: tl.constexpr,
 BLOCK_SIZE_M: tl.constexpr,
 ):
@@ -228,7 +228,7 @@ def _preprocess_after_permute_kernel(
 pid_y = tl.program_id(1)
 if pid_y == 0:
 token_offsets = pid_x * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-token_mask = token_offsets < TOTAL_TOKENS
+token_mask = token_offsets < total_tokens
 # get expert_id for each token in the block
 expert_ids = tl.full((BLOCK_SIZE_M, ), NUM_EXPERTS - 1, dtype=tl.int32)
 found_mask = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.int1)
@@ -287,7 +287,7 @@ def preprocess_after_permute(expert_first_token_offset_tensor,
 expert_first_token_offset_tensor,
 masked_m,
 token_to_expert_map,
-TOTAL_TOKENS=total_tokens,
+total_tokens,
 NUM_EXPERTS=num_experts,
 BLOCK_SIZE_M=BLOCK_SIZE_M,
 )