mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5488582][fix] Cherry-pick 7495: Avoid unexpected Triton recompilation in DG fused_moe (#7708)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
This commit is contained in:
parent
8bdbb48264
commit
6313c9799c
@@ -220,7 +220,7 @@ def _preprocess_after_permute_kernel(
|
||||
expert_offsets_ptr,
|
||||
masked_m_ptr,
|
||||
token_map_ptr,
|
||||
-TOTAL_TOKENS: tl.constexpr,
|
||||
+total_tokens,
|
||||
NUM_EXPERTS: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
):
|
||||
@@ -228,7 +228,7 @@ def _preprocess_after_permute_kernel(
|
||||
pid_y = tl.program_id(1)
|
||||
if pid_y == 0:
|
||||
token_offsets = pid_x * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
-token_mask = token_offsets < TOTAL_TOKENS
|
||||
+token_mask = token_offsets < total_tokens
|
||||
# get expert_id for each token in the block
|
||||
expert_ids = tl.full((BLOCK_SIZE_M, ), NUM_EXPERTS - 1, dtype=tl.int32)
|
||||
found_mask = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.int1)
|
||||
@@ -287,7 +287,7 @@ def preprocess_after_permute(expert_first_token_offset_tensor,
|
||||
expert_first_token_offset_tensor,
|
||||
masked_m,
|
||||
token_to_expert_map,
|
||||
-TOTAL_TOKENS=total_tokens,
|
||||
+total_tokens,
|
||||
NUM_EXPERTS=num_experts,
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user