[TRTLLM-10147][perf] Balanced random MoE workload generator for CuteDSL kernel UT, autotuner and layerwise benchmark (#10279)

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Enwei Zhu 2026-01-25 21:02:30 +08:00 committed by GitHub
parent fd7fd8c39d
commit 72ef732bcf
6 changed files with 415 additions and 233 deletions

View File

@ -1,4 +1,5 @@
import itertools
import math
from typing import List, Optional, Tuple
import torch
@ -31,13 +32,19 @@ class GroupedGemmInputsHelper:
IDX_A = 0
IDX_SHAPE_INFER = IDX_A # Default: use a tensor for shape inference
def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
local_expert_offset: int, tile_size: int):
def __init__(self,
num_experts: int,
top_k: int,
num_local_experts: int,
local_expert_offset: int,
tile_size: int,
seed: int = 515):
self.num_experts = num_experts
self.top_k = top_k
self.num_local_experts = num_local_experts
self.local_expert_offset = local_expert_offset
self.tile_size = tile_size
self.seed = seed
# Padding values should never be accessed.
# Intentionally use a large padding value to expose issues early.
self.pad_val = int(2e9)
@ -82,118 +89,134 @@ class GroupedGemmInputsHelper:
self, input_shapes: List[torch.Size]) -> int:
return self.infer_shape_max_num_tiles(input_shapes) * self.tile_size
def generate_num_tokens_per_expert(self, num_tokens: int) -> List[int]:
average_num_tokens_per_expert = num_tokens * self.top_k / self.num_experts
balance = 0
num_tokens_per_expert = []
for i in range(self.num_local_experts):
balance += average_num_tokens_per_expert
if balance <= 1e-3:
continue
curr_num_tokens = int(balance) + 1
num_tokens_per_expert.append(curr_num_tokens)
balance -= curr_num_tokens
def generate_num_tokens_per_expert(self,
num_tokens: int,
approx_max_load: bool = False
) -> List[int]:
ep_size = self.num_experts // self.num_local_experts
average_num_tokens_per_rank = num_tokens * self.top_k / ep_size
if approx_max_load:
# https://en.wikipedia.org/wiki/Balls_into_bins_problem
# The constant c can be measured empirically; we choose 1.0 for simplicity.
c = 1.0
extra_num_tokens_on_curr_rank = c * math.sqrt(
average_num_tokens_per_rank * math.log(ep_size))
num_tokens_on_curr_rank = math.ceil(average_num_tokens_per_rank +
extra_num_tokens_on_curr_rank)
else:
num_tokens_on_curr_rank = math.ceil(average_num_tokens_per_rank)
num_tokens_on_curr_rank = min(num_tokens * self.top_k,
num_tokens_on_curr_rank)
base, remainder = divmod(num_tokens_on_curr_rank,
self.num_local_experts)
num_tokens_per_expert = [base + 1] * remainder + [base] * (
self.num_local_experts - remainder)
assert len(num_tokens_per_expert) == self.num_local_experts
assert sum(num_tokens_per_expert) == num_tokens_on_curr_rank
return num_tokens_per_expert
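For intuition, here is a minimal standalone sketch of the same split logic (the function name and the example sizes below are illustrative, not taken from this change):

```python
import math

def split_tokens_evenly(num_tokens, top_k, num_experts, num_local_experts,
                        approx_max_load=False, c=1.0):
    # Expected number of expanded tokens landing on this EP rank under uniform routing.
    ep_size = num_experts // num_local_experts
    load = num_tokens * top_k / ep_size
    if approx_max_load:
        # Balls-into-bins: the most loaded rank exceeds the mean by roughly c * sqrt(mean * ln(ep_size)).
        load += c * math.sqrt(load * math.log(ep_size))
    on_rank = min(num_tokens * top_k, math.ceil(load))
    base, rem = divmod(on_rank, num_local_experts)
    return [base + 1] * rem + [base] * (num_local_experts - rem)

# 128 tokens, top-8 routing, 256 experts across EP=32 ranks (8 local experts per rank):
print(split_tokens_evenly(128, 8, 256, 8))                        # [4, 4, 4, 4, 4, 4, 4, 4]
print(split_tokens_evenly(128, 8, 256, 8, approx_max_load=True))  # [6, 6, 6, 5, 5, 5, 5, 5] with c = 1.0
```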
def generate_tile_idx_to_group_idx(
self, num_tokens_per_expert: List[int]) -> List[int]:
tile_idx_to_group_idx = []
for i, curr_num_tokens in enumerate(num_tokens_per_expert):
curr_num_tiles = (curr_num_tokens + self.tile_size -
1) // self.tile_size
tile_idx_to_group_idx.extend([i] * curr_num_tiles)
return tile_idx_to_group_idx
def generate_tile_idx_to_mn_limit(
self, num_tokens_per_expert: List[int]) -> List[int]:
tile_idx_to_mn_limit = []
for i, curr_num_tokens in enumerate(num_tokens_per_expert):
curr_num_tiles = (curr_num_tokens + self.tile_size -
1) // self.tile_size
prev_mn_limit = len(tile_idx_to_mn_limit) * self.tile_size
for j in range(curr_num_tiles):
tile_idx_to_mn_limit.append(prev_mn_limit + min(
(j + 1) * self.tile_size, curr_num_tokens))
return tile_idx_to_mn_limit
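A toy walk-through of the tile metadata these two generators produce (tile_size = 4 is used here only for readability; the real kernels use 128 or 256):

```python
tile_size = 4
num_tokens_per_expert = [6, 3]

tile_idx_to_group_idx, tile_idx_to_mn_limit = [], []
for expert_idx, n in enumerate(num_tokens_per_expert):
    num_tiles = (n + tile_size - 1) // tile_size
    prev_limit = len(tile_idx_to_group_idx) * tile_size  # padded rows claimed by earlier experts
    for j in range(num_tiles):
        tile_idx_to_group_idx.append(expert_idx)
        tile_idx_to_mn_limit.append(prev_limit + min((j + 1) * tile_size, n))

print(tile_idx_to_group_idx)  # [0, 0, 1]   expert 0 owns two tiles, expert 1 owns one
print(tile_idx_to_mn_limit)   # [4, 6, 11]  exclusive upper bound of valid rows per tile in the padded layout
```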
def generate_permuted_idx_to_expanded_idx(
def generate_token_selected_experts(
self, num_tokens: int,
num_tokens_per_expert: List[int]) -> List[int]:
permuted_idx_to_expanded_idx = []
colmajor_expanded_idx = 0
for i, curr_num_tokens in enumerate(num_tokens_per_expert):
curr_num_tiles = (curr_num_tokens + self.tile_size -
1) // self.tile_size
for j in range(curr_num_tiles * self.tile_size):
if j < curr_num_tokens:
token_idx = colmajor_expanded_idx % num_tokens
topk_idx = colmajor_expanded_idx // num_tokens
expanded_idx = token_idx * self.top_k + topk_idx
permuted_idx_to_expanded_idx.append(expanded_idx)
colmajor_expanded_idx += 1
else:
permuted_idx_to_expanded_idx.append(self.pad_val)
return permuted_idx_to_expanded_idx
num_tokens_per_expert: List[int]) -> torch.Tensor:
"""Balanced random based on rejection sampling.
"""
token_selected_experts = -torch.ones(
num_tokens, self.top_k, dtype=torch.int32)
num_selected_experts = torch.zeros(num_tokens, dtype=torch.int32)
with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
torch.manual_seed(self.seed)
selection_orders = [
torch.randperm(num_tokens)
for _ in range(self.num_local_experts)
]
for j, num_tokens_j in enumerate(num_tokens_per_expert):
selection_order_j = selection_orders[j].tolist()
prioritized = torch.nonzero(num_selected_experts <= (
self.top_k - (self.num_experts - j))).squeeze(-1).tolist()
if len(prioritized) > 0:
selection_order_j = prioritized + [
i for i in selection_order_j if i not in prioritized
]
for i in selection_order_j:
if num_selected_experts[i] < self.top_k:
token_selected_experts[
i,
num_selected_experts[i]] = j + self.local_expert_offset
num_selected_experts[i] += 1
num_tokens_j -= 1
if num_tokens_j <= 0:
break
assert ((token_selected_experts
>= 0).sum(dim=-1) == num_selected_experts).all().item()
if self.num_local_experts == self.num_experts:
assert (num_selected_experts == self.top_k).all().item()
else:
assert (num_selected_experts <= self.top_k).all().item()
return token_selected_experts
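A short usage sketch of the helper on a CUDA-enabled machine (the sizes are illustrative; the import path is the one used elsewhere in this change):

```python
import torch
from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import GroupedGemmInputsHelper

helper = GroupedGemmInputsHelper(num_experts=8, top_k=2, num_local_experts=8,
                                 local_expert_offset=0, tile_size=128)
num_tokens = 16
num_tokens_per_expert = helper.generate_num_tokens_per_expert(num_tokens)
# With all experts local, every expanded token is placed exactly once.
assert sum(num_tokens_per_expert) == num_tokens * helper.top_k
token_selected_experts = helper.generate_token_selected_experts(num_tokens, num_tokens_per_expert)
print(token_selected_experts.shape)  # torch.Size([16, 2])
print(torch.bincount(token_selected_experts.flatten().long(), minlength=8))  # per-expert token counts
```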
def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others = inputs
num_tokens = self.infer_num_tokens(a.size(0))
num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens)
tile_idx_to_group_idx_list = self.generate_tile_idx_to_group_idx(
num_tokens_per_expert)
num_non_exiting_tiles_val = len(tile_idx_to_group_idx_list)
num_padding_tiles_val = tile_idx_to_group_idx.size(
0) - num_non_exiting_tiles_val
assert num_non_exiting_tiles_val > 0
assert num_padding_tiles_val >= 0
num_tokens_per_expert = self.generate_num_tokens_per_expert(
num_tokens, approx_max_load=True)
token_selected_experts = self.generate_token_selected_experts(
num_tokens, num_tokens_per_expert)
tile_idx_to_group_idx = torch.tensor(
tile_idx_to_group_idx_list + [self.pad_val] * num_padding_tiles_val,
dtype=tile_idx_to_group_idx.dtype,
device=tile_idx_to_group_idx.device)
num_non_exiting_tiles = torch.tensor(
[num_non_exiting_tiles_val],
dtype=num_non_exiting_tiles.dtype,
device=num_non_exiting_tiles.device)
token_selected_experts = token_selected_experts.cuda()
token_final_scales = torch.ones_like(token_selected_experts,
dtype=torch.float32)
(
tile_idx_to_group_idx,
tile_idx_to_mn_limit,
expanded_idx_to_permuted_idx,
permuted_idx_to_expanded_idx,
total_num_padded_tokens,
num_non_exiting_tiles,
) = torch.ops.trtllm.moe_sort(
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
num_experts=self.num_experts,
top_k=self.top_k,
local_expert_offset=self.local_expert_offset,
local_num_experts=self.num_local_experts,
tile_tokens_dim=self.tile_size,
)
return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others
def inputs_pre_hook_finalize_fusion(
self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
a, b, a_sf, b_sf, alpha, output, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs
num_tokens = self.infer_num_tokens(a.size(0))
num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens)
tile_idx_to_group_idx_list = self.generate_tile_idx_to_group_idx(
num_tokens_per_expert)
tile_idx_to_mn_limit_list = self.generate_tile_idx_to_mn_limit(
num_tokens_per_expert)
permuted_idx_to_expanded_idx_list = self.generate_permuted_idx_to_expanded_idx(
num_tokens_per_expert = self.generate_num_tokens_per_expert(
num_tokens, approx_max_load=True)
token_selected_experts = self.generate_token_selected_experts(
num_tokens, num_tokens_per_expert)
num_non_exiting_tiles_val = len(tile_idx_to_group_idx_list)
num_padding_tiles_val = tile_idx_to_group_idx.size(
0) - num_non_exiting_tiles_val
assert num_non_exiting_tiles_val > 0
assert num_padding_tiles_val >= 0
assert len(tile_idx_to_mn_limit_list) == num_non_exiting_tiles_val
assert len(permuted_idx_to_expanded_idx_list
) == num_non_exiting_tiles_val * self.tile_size
tile_idx_to_group_idx = torch.tensor(
tile_idx_to_group_idx_list + [self.pad_val] * num_padding_tiles_val,
dtype=tile_idx_to_group_idx.dtype,
device=tile_idx_to_group_idx.device)
tile_idx_to_mn_limit = torch.tensor(
tile_idx_to_mn_limit_list + [self.pad_val] * num_padding_tiles_val,
dtype=tile_idx_to_mn_limit.dtype,
device=tile_idx_to_mn_limit.device)
permuted_idx_to_expanded_idx = torch.tensor(
permuted_idx_to_expanded_idx_list + [self.pad_val] *
(num_padding_tiles_val * self.tile_size),
dtype=permuted_idx_to_expanded_idx.dtype,
device=permuted_idx_to_expanded_idx.device)
num_non_exiting_tiles = torch.tensor(
[num_non_exiting_tiles_val],
dtype=num_non_exiting_tiles.dtype,
device=num_non_exiting_tiles.device)
token_selected_experts = token_selected_experts.cuda()
token_final_scales = torch.ones_like(token_selected_experts,
dtype=torch.float32)
(
tile_idx_to_group_idx,
tile_idx_to_mn_limit,
expanded_idx_to_permuted_idx,
permuted_idx_to_expanded_idx,
total_num_padded_tokens,
num_non_exiting_tiles,
) = torch.ops.trtllm.moe_sort(
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
num_experts=self.num_experts,
top_k=self.top_k,
local_expert_offset=self.local_expert_offset,
local_num_experts=self.num_local_experts,
tile_tokens_dim=self.tile_size,
)
return a, b, a_sf, b_sf, alpha, output, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales
@ -221,44 +244,6 @@ class GatherGroupedGemmInputsHelper(GroupedGemmInputsHelper):
IDX_PERMUTED_IDX_TO_EXPANDED_IDX = 7
IDX_SHAPE_INFER = IDX_PERMUTED_IDX_TO_EXPANDED_IDX
def generate_permuted_idx_to_expanded_idx(
self, num_tokens: int, num_tokens_per_expert: List[int],
max_num_permuted_tokens: int) -> List[int]:
"""Generate permuted_idx_to_expanded_idx for gather operation.
Maps permuted index to expanded index (token_idx * top_k + topk_idx).
Args:
num_tokens: Total number of input tokens
num_tokens_per_expert: List of token counts per expert
max_num_permuted_tokens: Target size of the output list
Returns:
List of expanded IDs with length = max_num_permuted_tokens,
where permuted_idx_to_expanded_idx[permuted_idx] = expanded_idx
Padding tokens are marked with pad_val
Note: In kernel, use expanded_idx // top_k to get original token_idx
"""
permuted_idx_to_expanded_idx = []
colmajor_expanded_idx = 0
for i, curr_num_tokens in enumerate(num_tokens_per_expert):
curr_num_tiles = (curr_num_tokens + self.tile_size -
1) // self.tile_size
for j in range(curr_num_tiles * self.tile_size):
if j < curr_num_tokens:
token_idx = colmajor_expanded_idx % num_tokens
topk_idx = colmajor_expanded_idx // num_tokens
expanded_idx = token_idx * self.top_k + topk_idx
permuted_idx_to_expanded_idx.append(expanded_idx)
colmajor_expanded_idx += 1
else:
permuted_idx_to_expanded_idx.append(
self.pad_val) # Padding token
# Pad to max_num_permuted_tokens
while len(permuted_idx_to_expanded_idx) < max_num_permuted_tokens:
permuted_idx_to_expanded_idx.append(self.pad_val)
return permuted_idx_to_expanded_idx
def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
"""Pre-hook for gather-based SwiGLU fusion kernel.
@ -276,39 +261,31 @@ class GatherGroupedGemmInputsHelper(GroupedGemmInputsHelper):
IDX_PERMUTED_IDX_TO_EXPANDED_IDX] is permuted_idx_to_expanded_idx
max_num_permuted_tokens = permuted_idx_to_expanded_idx.size(0)
max_num_tiles = max_num_permuted_tokens // self.tile_size
num_tokens = self.infer_num_tokens(max_num_permuted_tokens)
num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens)
tile_idx_to_group_idx_list = self.generate_tile_idx_to_group_idx(
num_tokens_per_expert)
tile_idx_to_mn_limit_list = self.generate_tile_idx_to_mn_limit(
num_tokens_per_expert)
permuted_idx_to_expanded_idx_list = self.generate_permuted_idx_to_expanded_idx(
num_tokens, num_tokens_per_expert, max_num_permuted_tokens)
num_non_exiting_tiles_val = len(tile_idx_to_group_idx_list)
num_padding_tiles_val = max_num_tiles - num_non_exiting_tiles_val
assert num_non_exiting_tiles_val > 0
assert num_padding_tiles_val >= 0
assert len(tile_idx_to_mn_limit_list) == num_non_exiting_tiles_val
assert len(permuted_idx_to_expanded_idx_list) == max_num_permuted_tokens
num_tokens_per_expert = self.generate_num_tokens_per_expert(
num_tokens, approx_max_load=True)
token_selected_experts = self.generate_token_selected_experts(
num_tokens, num_tokens_per_expert)
tile_idx_to_group_idx = torch.tensor(
tile_idx_to_group_idx_list + [self.pad_val] * num_padding_tiles_val,
dtype=tile_idx_to_group_idx.dtype,
device=tile_idx_to_group_idx.device)
tile_idx_to_mn_limit = torch.tensor(
tile_idx_to_mn_limit_list + [self.pad_val] * num_padding_tiles_val,
dtype=tile_idx_to_mn_limit.dtype,
device=tile_idx_to_mn_limit.device)
permuted_idx_to_expanded_idx = torch.tensor(
permuted_idx_to_expanded_idx_list,
dtype=permuted_idx_to_expanded_idx.dtype,
device=permuted_idx_to_expanded_idx.device)
num_non_exiting_tiles = torch.tensor(
[num_non_exiting_tiles_val],
dtype=num_non_exiting_tiles.dtype,
device=num_non_exiting_tiles.device)
token_selected_experts = token_selected_experts.cuda()
token_final_scales = torch.ones_like(token_selected_experts,
dtype=torch.float32)
(
tile_idx_to_group_idx,
tile_idx_to_mn_limit,
expanded_idx_to_permuted_idx,
permuted_idx_to_expanded_idx,
total_num_padded_tokens,
num_non_exiting_tiles,
) = torch.ops.trtllm.moe_sort(
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
num_experts=self.num_experts,
top_k=self.top_k,
local_expert_offset=self.local_expert_offset,
local_num_experts=self.num_local_experts,
tile_tokens_dim=self.tile_size,
)
return (a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx,
tile_idx_to_mn_limit, permuted_idx_to_expanded_idx,
num_non_exiting_tiles, global_sf)
@ -924,6 +901,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
5, 0,
helper.infer_shape_max_num_tiles)),
inputs_pre_hook=helper.inputs_pre_hook,
use_cold_l2_cache=True,
)
return self.__class__.tuning_config_cache[key]
@ -1222,6 +1200,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
8, 0, helper.infer_shape_max_num_permuted_tokens),
ConstraintSpec(10, 0, helper.infer_shape_num_tokens)),
inputs_pre_hook=helper.inputs_pre_hook_finalize_fusion,
use_cold_l2_cache=True,
)
return self.__class__.tuning_config_cache[key]
@ -1602,6 +1581,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
5, 0,
helper.infer_shape_max_num_tiles)),
inputs_pre_hook=helper.inputs_pre_hook,
use_cold_l2_cache=True,
)
return self.__class__.tuning_config_cache[key]
@ -1931,6 +1911,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
6, 0,
helper.infer_shape_max_num_tiles)),
inputs_pre_hook=helper.inputs_pre_hook,
use_cold_l2_cache=True,
)
return self.__class__.tuning_config_cache[key]

View File

@ -175,7 +175,8 @@ class CuteDslFusedMoENvfp4InputsHelper(GroupedGemmInputsHelper):
def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
x, token_selected_experts, *others = inputs
num_tokens = token_selected_experts.size(0)
num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens)
num_tokens_per_expert = self.generate_num_tokens_per_expert(
num_tokens, approx_max_load=True)
new_token_selected_experts = []
for i, curr_num_tokens in enumerate(num_tokens_per_expert,
@ -258,6 +259,7 @@ class CuteDslFusedMoENvfp4Runner(TunableRunner):
ConstraintSpec(
4, 0, helper.infer_shape_num_tokens)),
inputs_pre_hook=helper.inputs_pre_hook,
use_cold_l2_cache=True,
)
return self.__class__.tuning_config_cache[key]

View File

@ -1,6 +1,7 @@
import contextlib
import functools
import itertools
import os
import unittest.mock
import weakref
from enum import IntEnum
@ -11,6 +12,7 @@ import torch
import tensorrt_llm._torch.model_config
import tensorrt_llm.bindings
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import GroupedGemmInputsHelper
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import PostInitCaller, skip_forward
@ -54,8 +56,15 @@ def round_up(a, b):
return ceil_div(a, b) * b
def get_balanced_selection_no_cache(
num_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size
def get_balanced_selection_impl_default(
num_tokens: int,
top_k: int,
num_experts: int,
dtype: torch.dtype,
device: torch.device,
dp_size: int,
dp_rank: int,
ep_size: int,
):
token_id = torch.arange(dp_rank * num_tokens * top_k, (dp_rank + 1) * num_tokens * top_k).view(
num_tokens, top_k
@ -68,6 +77,32 @@ def get_balanced_selection_no_cache(
return token_selected_experts.contiguous().to(dtype=dtype, device=device)
def get_balanced_selection_impl_random(
num_tokens: int,
top_k: int,
num_experts: int,
dtype: torch.dtype,
device: torch.device,
dp_size: int,
dp_rank: int,
ep_size: int,
):
helper = GroupedGemmInputsHelper(num_experts, top_k, num_experts, 0, 128)
num_tokens_per_expert = helper.generate_num_tokens_per_expert(num_tokens, approx_max_load=False)
assert sum(num_tokens_per_expert) == num_tokens * top_k
token_selected_experts = helper.generate_token_selected_experts(
num_tokens, num_tokens_per_expert
)
return token_selected_experts.contiguous().to(dtype=dtype, device=device)
def get_balanced_selection_no_cache(*args, **kwargs):
if os.environ.get("TRTLLM_LAYERWISE_BENCHMARK_BALANCED_IMPL", "DEFAULT") == "RANDOM":
return get_balanced_selection_impl_random(*args, **kwargs)
else:
return get_balanced_selection_impl_default(*args, **kwargs)
get_balanced_selection = functools.cache(get_balanced_selection_no_cache)
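For instance, a sketch of opting into the new generator before running the layer-wise benchmark (equivalently, export the variable in the shell; the variable name is the one read above):

```python
import os

# Use the balanced random routing generator instead of the default deterministic one.
os.environ["TRTLLM_LAYERWISE_BENCHMARK_BALANCED_IMPL"] = "RANDOM"
```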

View File

@ -0,0 +1,11 @@
# Launch Scripts for CuTe DSL Kernels
## MoE Workload Generator
```bash
# Generate a workload using the balanced random method
# Per-rank token count 128, EP size 32 (a typical workload for the large-EP generation phase)
python moe_workload_generator.py --num_tokens 128 --ep_size 32 --tile_size 128
# Per-rank token count 8192, EP size 4 (a typical workload for the context phase)
python moe_workload_generator.py --num_tokens 8192 --ep_size 4 --tile_size 256
```
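The script writes `metadata.json` and `workload.safetensors` under `--output_path` (default `./moe_workload`); a minimal sketch for loading them back:

```python
import json

import safetensors.torch

output_path = "./moe_workload"
with open(f"{output_path}/metadata.json") as f:
    metadata = json.load(f)
workload = safetensors.torch.load_file(f"{output_path}/workload.safetensors")
print(metadata["num_tokens"], metadata["ep_size"], metadata["tile_size"])
print(workload["tile_idx_to_group_idx"].shape, workload["num_non_exiting_tiles"])
```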

View File

@ -0,0 +1,176 @@
import json
import os
from typing import List, Optional
import click
import safetensors.torch
import torch
from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import GroupedGemmInputsHelper
from tensorrt_llm.tools.layer_wise_benchmarks.runner import (
get_balanced_selection_impl_default,
get_balanced_selection_impl_random,
)
def get_balanced_selection_impl_default_legacy(
num_tokens: int,
top_k: int,
num_experts: int,
dtype: torch.dtype,
device: torch.device,
dp_size: int,
dp_rank: int,
ep_size: int,
):
world_size = ep_size
rank = dp_rank
# First, each sender selects target rank
target_rank_before_mod = torch.arange(num_tokens * world_size * top_k).view(
num_tokens, world_size, top_k
)
target_rank_before_mod += top_k * torch.arange(num_tokens).view(
num_tokens, 1, 1
) # Shift `top_k` ranks for the next token on each rank, to balance network traffic
target_rank = target_rank_before_mod % world_size
# Second, each receiver selects target expert
target_expert = torch.empty_like(target_rank)
for receiver_rank in range(world_size):
mask = target_rank == receiver_rank
experts_per_rank = num_experts // world_size
local_expert = torch.arange(num_tokens * top_k) % experts_per_rank
target_expert[mask] = (receiver_rank * experts_per_rank) + local_expert
token_selected_experts = target_expert[:, rank].sort(dim=-1).values
return token_selected_experts.contiguous().to(dtype=dtype, device=device)
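A quick illustrative balance check (my own snippet, not part of the change) that can be applied to the output of any of these selection methods:

```python
import torch

def check_balance(token_selected_experts: torch.Tensor, num_experts: int) -> None:
    # Expanded-token count per expert vs. the ideal uniform load.
    counts = torch.bincount(token_selected_experts.flatten().long(), minlength=num_experts)
    ideal = token_selected_experts.numel() / num_experts
    print(f"min={counts.min().item()}, max={counts.max().item()}, ideal={ideal:.1f}")
```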
def gen_moe_workload(
num_tokens: int,
top_k: int,
num_experts: int,
ep_size: int,
num_tokens_per_expert: Optional[List[int]],
tile_size: int,
method: str = "balanced_random",
):
if num_tokens_per_expert is not None:
num_local_experts = len(num_tokens_per_expert)
assert num_local_experts * ep_size == num_experts
helper = GroupedGemmInputsHelper(num_experts, top_k, num_local_experts, 0, tile_size)
token_selected_experts = helper.generate_token_selected_experts(
num_tokens * ep_size, num_tokens_per_expert
)
token_selected_experts = token_selected_experts.cuda()
else:
if method == "balanced_random":
get_balanced_selection_impl = get_balanced_selection_impl_random
elif method == "balanced_default":
get_balanced_selection_impl = get_balanced_selection_impl_default
elif method == "balanced_default_legacy":
get_balanced_selection_impl = get_balanced_selection_impl_default_legacy
else:
raise ValueError(f"Invalid method: {method}.")
token_selected_experts = [
get_balanced_selection_impl(
num_tokens=num_tokens,
top_k=top_k,
num_experts=num_experts,
dtype=torch.int32,
device="cuda",
dp_size=ep_size,
dp_rank=i,
ep_size=ep_size,
)
for i in range(ep_size)
]
token_selected_experts = torch.cat(token_selected_experts, dim=0)
assert token_selected_experts.size() == (num_tokens * ep_size, top_k)
token_final_scales = torch.ones_like(token_selected_experts, dtype=torch.float32)
return torch.ops.trtllm.moe_sort(
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
num_experts=num_experts,
top_k=top_k,
local_expert_offset=0,
local_num_experts=num_experts // ep_size,
tile_tokens_dim=tile_size,
)
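Calling the generator from Python instead of the CLI (argument values mirror the README example above; a CUDA device and the trtllm ops are required):

```python
(tile_idx_to_group_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx,
 permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles) = gen_moe_workload(
     num_tokens=128, top_k=8, num_experts=256, ep_size=32,
     num_tokens_per_expert=None, tile_size=128,
)
print(num_non_exiting_tiles, total_num_padded_tokens)
```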
@click.command("moe_workload_generator")
@click.option("--num_tokens", type=int, default=128)
@click.option("--top_k", type=int, default=8)
@click.option("--num_experts", type=int, default=256)
@click.option("--ep_size", type=int, default=32)
@click.option("--num_tokens_per_expert", type=str, default=None)
@click.option("--tile_size", type=click.Choice([128, 256]), default=128)
@click.option(
"--method",
type=click.Choice(["balanced_random", "balanced_default", "balanced_default_legacy"]),
default="balanced_random",
)
@click.option("--seed", type=int, default=515)
@click.option("--output_path", type=str, default="./moe_workload")
def main(
num_tokens: int,
top_k: int,
num_experts: int,
ep_size: int,
num_tokens_per_expert: str,
tile_size: int,
method: str,
seed: int,
output_path: str,
):
torch.manual_seed(seed)
if num_tokens_per_expert is not None:
num_tokens_per_expert = [int(x) for x in num_tokens_per_expert.split(",")]
(
tile_idx_to_group_idx,
tile_idx_to_mn_limit,
expanded_idx_to_permuted_idx,
permuted_idx_to_expanded_idx,
total_num_padded_tokens,
num_non_exiting_tiles,
) = gen_moe_workload(
num_tokens=num_tokens,
top_k=top_k,
num_experts=num_experts,
ep_size=ep_size,
num_tokens_per_expert=num_tokens_per_expert,
tile_size=tile_size,
method=method,
)
if not os.path.isdir(output_path):
os.makedirs(output_path)
metadata = {
"num_tokens": num_tokens,
"top_k": top_k,
"num_experts": num_experts,
"ep_size": ep_size,
"tile_size": tile_size,
"method": method,
"seed": seed,
}
with open(f"{output_path}/metadata.json", "w") as f:
json.dump(metadata, f)
workload = {
"tile_idx_to_group_idx": tile_idx_to_group_idx,
"tile_idx_to_mn_limit": tile_idx_to_mn_limit,
"expanded_idx_to_permuted_idx": expanded_idx_to_permuted_idx,
"permuted_idx_to_expanded_idx": permuted_idx_to_expanded_idx,
"total_num_padded_tokens": total_num_padded_tokens,
"num_non_exiting_tiles": num_non_exiting_tiles,
}
safetensors.torch.save_file(workload, f"{output_path}/workload.safetensors")
if __name__ == "__main__":
main()

View File

@ -2,10 +2,7 @@ import pytest
import torch
from utils.util import check_accuracy
from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import (
GatherGroupedGemmInputsHelper,
GroupedGemmInputsHelper,
)
from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import GroupedGemmInputsHelper
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import cute_dsl_nvfp4_grouped_gemm_ref
from tensorrt_llm._torch.modules.fused_moe.quantization import interleave_linear_and_gate
from tensorrt_llm._torch.utils import swizzle_sf, unswizzle_sf
@ -740,55 +737,29 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell(
# Generate routing information
routing_logits = torch.randn(num_tokens, num_experts, device="cuda")
_, token_selected_experts = routing_logits.topk(top_k, dim=-1)
token_final_scales, token_selected_experts = routing_logits.topk(top_k, dim=-1)
token_selected_experts = token_selected_experts.to(torch.int32)
num_tokens_per_expert = torch.bincount(token_selected_experts.flatten(), minlength=num_experts)
num_tokens_per_expert = num_tokens_per_expert[:num_local_experts]
# Ensure at least one valid token
if num_tokens_per_expert.sum().item() == 0:
num_tokens_per_expert[0] = 1
num_tiles_per_expert = (num_tokens_per_expert + tile_size - 1) // tile_size
num_tokens_per_expert = num_tokens_per_expert.cpu()
num_tiles_per_expert = num_tiles_per_expert.cpu()
num_valid_tiles = num_tiles_per_expert.sum().item()
num_valid_permuted_tokens = num_valid_tiles * tile_size
token_final_scales = token_final_scales.softmax(dim=-1).to(torch.float32)
# Create helper
helper = GatherGroupedGemmInputsHelper(num_experts, top_k, num_local_experts, 0, tile_size)
max_num_tiles = helper.get_max_num_tiles(num_tokens)
max_num_permuted_tokens = helper.get_max_num_permuted_tokens(num_tokens)
assert 0 <= num_valid_tiles <= max_num_tiles
assert 0 <= num_valid_permuted_tokens <= max_num_permuted_tokens
# Generate tile metadata
num_non_exiting_tiles = torch.tensor([num_valid_tiles], dtype=torch.int32, device="cuda")
tile_idx_to_group_idx = torch.empty(max_num_tiles, dtype=torch.int32)
tile_idx_to_mn_limit = torch.empty(max_num_tiles, dtype=torch.int32)
tile_idx_to_group_idx.fill_(int(-2e9))
tile_idx_to_mn_limit.fill_(int(-2e9))
tile_idx_to_group_idx_list = helper.generate_tile_idx_to_group_idx(
num_tokens_per_expert.tolist()
(
tile_idx_to_group_idx,
tile_idx_to_mn_limit,
expanded_idx_to_permuted_idx,
permuted_idx_to_expanded_idx,
total_num_padded_tokens,
num_non_exiting_tiles,
) = torch.ops.trtllm.moe_sort(
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
num_experts=num_experts,
top_k=top_k,
local_expert_offset=0,
local_num_experts=num_local_experts,
tile_tokens_dim=tile_size,
)
tile_idx_to_mn_limit_list = helper.generate_tile_idx_to_mn_limit(num_tokens_per_expert.tolist())
for idx, (group_idx, mn_limit) in enumerate(
zip(tile_idx_to_group_idx_list, tile_idx_to_mn_limit_list)
):
tile_idx_to_group_idx[idx] = group_idx
tile_idx_to_mn_limit[idx] = mn_limit
tile_idx_to_group_idx = tile_idx_to_group_idx.cuda()
tile_idx_to_mn_limit = tile_idx_to_mn_limit.cuda()
# Generate permuted_idx_to_expanded_idx for gather operation
permuted_idx_to_expanded_idx_list = helper.generate_permuted_idx_to_expanded_idx(
num_tokens, num_tokens_per_expert.tolist(), max_num_permuted_tokens
)
permuted_idx_to_expanded_idx = torch.tensor(
permuted_idx_to_expanded_idx_list, dtype=torch.int32, device="cuda"
)
assert permuted_idx_to_expanded_idx.size(0) == max_num_permuted_tokens
max_num_permuted_tokens = permuted_idx_to_expanded_idx.size(0)
num_valid_permuted_tokens = total_num_padded_tokens.item()
# Create input tensors (original size, not permuted)
a = torch.randint(-5, 5, (num_tokens, hidden_size), dtype=torch.int32, device="cuda").to(
@ -826,18 +797,22 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell(
)
# Compute reference: manually gather, compute GEMM, apply SwiGLU, then quantize
a_gathered = torch.empty(
max_num_permuted_tokens, hidden_size // 2, dtype=a.dtype, device=a.device
)
permuted_idx_to_expanded_idx_list = permuted_idx_to_expanded_idx.cpu().tolist()
tile_idx_to_mn_limit_list = tile_idx_to_mn_limit.cpu().tolist()
a_gathered = torch.empty(max_num_permuted_tokens, hidden_size // 2, dtype=a.dtype)
a_sf_gathered = torch.empty(
max_num_permuted_tokens, hidden_size // sf_vec_size, dtype=a_sf.dtype, device=a_sf.device
max_num_permuted_tokens, hidden_size // sf_vec_size, dtype=a_sf.dtype
)
for i in range(num_valid_permuted_tokens):
expanded_idx = permuted_idx_to_expanded_idx[i].item()
if expanded_idx != helper.pad_val:
token_id = expanded_idx // top_k
a_gathered[i] = a[token_id]
a_sf_gathered[i] = a_sf_unswizzled[token_id]
if i >= tile_idx_to_mn_limit_list[i // tile_size]:
continue
expanded_idx = permuted_idx_to_expanded_idx_list[i]
token_id = expanded_idx // top_k
a_gathered[i] = a[token_id]
a_sf_gathered[i] = a_sf_unswizzled[token_id]
a_gathered = a_gathered.to(a.device)
a_sf_gathered = a_sf_gathered.to(a.device)
# Swizzle a_sf_gathered for reference GEMM
a_sf_gathered_swizzled = swizzle_sf(
@ -886,8 +861,9 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell(
# Create mask for valid tokens
valid_token_mask = torch.zeros(num_valid_permuted_tokens, dtype=torch.bool, device="cuda")
for i in range(num_valid_permuted_tokens):
if permuted_idx_to_expanded_idx[i].item() != helper.pad_val:
valid_token_mask[i] = True
if i >= tile_idx_to_mn_limit_list[i // tile_size]:
continue
valid_token_mask[i] = True
num_valid_tokens = valid_token_mask.sum().item()
if num_valid_tokens > 0:
@ -905,9 +881,10 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell(
c_sf_valid = []
c_sf_ref_valid = []
for i in range(num_valid_permuted_tokens):
if permuted_idx_to_expanded_idx[i].item() != helper.pad_val:
c_sf_valid.append(c_sf_unswizzled[i])
c_sf_ref_valid.append(c_sf_ref_unswizzled[i])
if i >= tile_idx_to_mn_limit_list[i // tile_size]:
continue
c_sf_valid.append(c_sf_unswizzled[i])
c_sf_ref_valid.append(c_sf_ref_unswizzled[i])
c_sf_valid = torch.cat(c_sf_valid)
c_sf_ref_valid = torch.cat(c_sf_ref_valid)