diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py index 771e7ed7c8..ff37aa7d95 100644 --- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py @@ -1,4 +1,5 @@ import itertools +import math from typing import List, Optional, Tuple import torch @@ -31,13 +32,19 @@ class GroupedGemmInputsHelper: IDX_A = 0 IDX_SHAPE_INFER = IDX_A # Default: use a tensor for shape inference - def __init__(self, num_experts: int, top_k: int, num_local_experts: int, - local_expert_offset: int, tile_size: int): + def __init__(self, + num_experts: int, + top_k: int, + num_local_experts: int, + local_expert_offset: int, + tile_size: int, + seed: int = 515): self.num_experts = num_experts self.top_k = top_k self.num_local_experts = num_local_experts self.local_expert_offset = local_expert_offset self.tile_size = tile_size + self.seed = seed # Padding values should never be accessed. # Intentionally use a large padding value to expose issues early. self.pad_val = int(2e9) @@ -82,118 +89,134 @@ class GroupedGemmInputsHelper: self, input_shapes: List[torch.Size]) -> int: return self.infer_shape_max_num_tiles(input_shapes) * self.tile_size - def generate_num_tokens_per_expert(self, num_tokens: int) -> List[int]: - average_num_tokens_per_expert = num_tokens * self.top_k / self.num_experts - balance = 0 - num_tokens_per_expert = [] - for i in range(self.num_local_experts): - balance += average_num_tokens_per_expert - if balance <= 1e-3: - continue - curr_num_tokens = int(balance) + 1 - num_tokens_per_expert.append(curr_num_tokens) - balance -= curr_num_tokens + def generate_num_tokens_per_expert(self, + num_tokens: int, + approx_max_load: bool = False + ) -> List[int]: + ep_size = self.num_experts // self.num_local_experts + average_num_tokens_per_rank = num_tokens * self.top_k / ep_size + + if approx_max_load: + # https://en.wikipedia.org/wiki/Balls_into_bins_problem + # The constant c can be measured empirically, we choose 1.0 for simplicity. 
+ c = 1.0 + extra_num_tokens_on_curr_rank = c * math.sqrt( + average_num_tokens_per_rank * math.log(ep_size)) + num_tokens_on_curr_rank = math.ceil(average_num_tokens_per_rank + + extra_num_tokens_on_curr_rank) + else: + num_tokens_on_curr_rank = math.ceil(average_num_tokens_per_rank) + + num_tokens_on_curr_rank = min(num_tokens * self.top_k, + num_tokens_on_curr_rank) + + base, remainder = divmod(num_tokens_on_curr_rank, + self.num_local_experts) + num_tokens_per_expert = [base + 1] * remainder + [base] * ( + self.num_local_experts - remainder) + assert len(num_tokens_per_expert) == self.num_local_experts + assert sum(num_tokens_per_expert) == num_tokens_on_curr_rank return num_tokens_per_expert - def generate_tile_idx_to_group_idx( - self, num_tokens_per_expert: List[int]) -> List[int]: - tile_idx_to_group_idx = [] - for i, curr_num_tokens in enumerate(num_tokens_per_expert): - curr_num_tiles = (curr_num_tokens + self.tile_size - - 1) // self.tile_size - tile_idx_to_group_idx.extend([i] * curr_num_tiles) - return tile_idx_to_group_idx - - def generate_tile_idx_to_mn_limit( - self, num_tokens_per_expert: List[int]) -> List[int]: - tile_idx_to_mn_limit = [] - for i, curr_num_tokens in enumerate(num_tokens_per_expert): - curr_num_tiles = (curr_num_tokens + self.tile_size - - 1) // self.tile_size - prev_mn_limit = len(tile_idx_to_mn_limit) * self.tile_size - for j in range(curr_num_tiles): - tile_idx_to_mn_limit.append(prev_mn_limit + min( - (j + 1) * self.tile_size, curr_num_tokens)) - return tile_idx_to_mn_limit - - def generate_permuted_idx_to_expanded_idx( + def generate_token_selected_experts( self, num_tokens: int, - num_tokens_per_expert: List[int]) -> List[int]: - permuted_idx_to_expanded_idx = [] - colmajor_expanded_idx = 0 - for i, curr_num_tokens in enumerate(num_tokens_per_expert): - curr_num_tiles = (curr_num_tokens + self.tile_size - - 1) // self.tile_size - for j in range(curr_num_tiles * self.tile_size): - if j < curr_num_tokens: - token_idx = colmajor_expanded_idx % num_tokens - topk_idx = colmajor_expanded_idx // num_tokens - expanded_idx = token_idx * self.top_k + topk_idx - permuted_idx_to_expanded_idx.append(expanded_idx) - colmajor_expanded_idx += 1 - else: - permuted_idx_to_expanded_idx.append(self.pad_val) - return permuted_idx_to_expanded_idx + num_tokens_per_expert: List[int]) -> torch.Tensor: + """Balanced random based on rejection sampling. 
+ """ + token_selected_experts = -torch.ones( + num_tokens, self.top_k, dtype=torch.int32) + num_selected_experts = torch.zeros(num_tokens, dtype=torch.int32) + + with torch.random.fork_rng(devices=[torch.cuda.current_device()]): + torch.manual_seed(self.seed) + selection_orders = [ + torch.randperm(num_tokens) + for _ in range(self.num_local_experts) + ] + + for j, num_tokens_j in enumerate(num_tokens_per_expert): + selection_order_j = selection_orders[j].tolist() + prioritized = torch.nonzero(num_selected_experts <= ( + self.top_k - (self.num_experts - j))).squeeze(-1).tolist() + if len(prioritized) > 0: + selection_order_j = prioritized + [ + i for i in selection_order_j if i not in prioritized + ] + for i in selection_order_j: + if num_selected_experts[i] < self.top_k: + token_selected_experts[ + i, + num_selected_experts[i]] = j + self.local_expert_offset + num_selected_experts[i] += 1 + num_tokens_j -= 1 + if num_tokens_j <= 0: + break + + assert ((token_selected_experts + >= 0).sum(dim=-1) == num_selected_experts).all().item() + if self.num_local_experts == self.num_experts: + assert (num_selected_experts == self.top_k).all().item() + else: + assert (num_selected_experts <= self.top_k).all().item() + return token_selected_experts def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]: a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others = inputs num_tokens = self.infer_num_tokens(a.size(0)) - num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens) - tile_idx_to_group_idx_list = self.generate_tile_idx_to_group_idx( - num_tokens_per_expert) - num_non_exiting_tiles_val = len(tile_idx_to_group_idx_list) - num_padding_tiles_val = tile_idx_to_group_idx.size( - 0) - num_non_exiting_tiles_val - assert num_non_exiting_tiles_val > 0 - assert num_padding_tiles_val >= 0 + num_tokens_per_expert = self.generate_num_tokens_per_expert( + num_tokens, approx_max_load=True) + token_selected_experts = self.generate_token_selected_experts( + num_tokens, num_tokens_per_expert) - tile_idx_to_group_idx = torch.tensor( - tile_idx_to_group_idx_list + [self.pad_val] * num_padding_tiles_val, - dtype=tile_idx_to_group_idx.dtype, - device=tile_idx_to_group_idx.device) - num_non_exiting_tiles = torch.tensor( - [num_non_exiting_tiles_val], - dtype=num_non_exiting_tiles.dtype, - device=num_non_exiting_tiles.device) + token_selected_experts = token_selected_experts.cuda() + token_final_scales = torch.ones_like(token_selected_experts, + dtype=torch.float32) + ( + tile_idx_to_group_idx, + tile_idx_to_mn_limit, + expanded_idx_to_permuted_idx, + permuted_idx_to_expanded_idx, + total_num_padded_tokens, + num_non_exiting_tiles, + ) = torch.ops.trtllm.moe_sort( + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + num_experts=self.num_experts, + top_k=self.top_k, + local_expert_offset=self.local_expert_offset, + local_num_experts=self.num_local_experts, + tile_tokens_dim=self.tile_size, + ) return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others def inputs_pre_hook_finalize_fusion( self, inputs: List[torch.Tensor]) -> List[torch.Tensor]: a, b, a_sf, b_sf, alpha, output, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs num_tokens = self.infer_num_tokens(a.size(0)) - num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens) - tile_idx_to_group_idx_list = self.generate_tile_idx_to_group_idx( - 
num_tokens_per_expert) - tile_idx_to_mn_limit_list = self.generate_tile_idx_to_mn_limit( - num_tokens_per_expert) - permuted_idx_to_expanded_idx_list = self.generate_permuted_idx_to_expanded_idx( + num_tokens_per_expert = self.generate_num_tokens_per_expert( + num_tokens, approx_max_load=True) + token_selected_experts = self.generate_token_selected_experts( num_tokens, num_tokens_per_expert) - num_non_exiting_tiles_val = len(tile_idx_to_group_idx_list) - num_padding_tiles_val = tile_idx_to_group_idx.size( - 0) - num_non_exiting_tiles_val - assert num_non_exiting_tiles_val > 0 - assert num_padding_tiles_val >= 0 - assert len(tile_idx_to_mn_limit_list) == num_non_exiting_tiles_val - assert len(permuted_idx_to_expanded_idx_list - ) == num_non_exiting_tiles_val * self.tile_size - tile_idx_to_group_idx = torch.tensor( - tile_idx_to_group_idx_list + [self.pad_val] * num_padding_tiles_val, - dtype=tile_idx_to_group_idx.dtype, - device=tile_idx_to_group_idx.device) - tile_idx_to_mn_limit = torch.tensor( - tile_idx_to_mn_limit_list + [self.pad_val] * num_padding_tiles_val, - dtype=tile_idx_to_mn_limit.dtype, - device=tile_idx_to_mn_limit.device) - permuted_idx_to_expanded_idx = torch.tensor( - permuted_idx_to_expanded_idx_list + [self.pad_val] * - (num_padding_tiles_val * self.tile_size), - dtype=permuted_idx_to_expanded_idx.dtype, - device=permuted_idx_to_expanded_idx.device) - num_non_exiting_tiles = torch.tensor( - [num_non_exiting_tiles_val], - dtype=num_non_exiting_tiles.dtype, - device=num_non_exiting_tiles.device) + token_selected_experts = token_selected_experts.cuda() + token_final_scales = torch.ones_like(token_selected_experts, + dtype=torch.float32) + ( + tile_idx_to_group_idx, + tile_idx_to_mn_limit, + expanded_idx_to_permuted_idx, + permuted_idx_to_expanded_idx, + total_num_padded_tokens, + num_non_exiting_tiles, + ) = torch.ops.trtllm.moe_sort( + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + num_experts=self.num_experts, + top_k=self.top_k, + local_expert_offset=self.local_expert_offset, + local_num_experts=self.num_local_experts, + tile_tokens_dim=self.tile_size, + ) return a, b, a_sf, b_sf, alpha, output, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales @@ -221,44 +244,6 @@ class GatherGroupedGemmInputsHelper(GroupedGemmInputsHelper): IDX_PERMUTED_IDX_TO_EXPANDED_IDX = 7 IDX_SHAPE_INFER = IDX_PERMUTED_IDX_TO_EXPANDED_IDX - def generate_permuted_idx_to_expanded_idx( - self, num_tokens: int, num_tokens_per_expert: List[int], - max_num_permuted_tokens: int) -> List[int]: - """Generate permuted_idx_to_expanded_idx for gather operation. - - Maps permuted index to expanded index (token_idx * top_k + topk_idx). 
- - Args: - num_tokens: Total number of input tokens - num_tokens_per_expert: List of token counts per expert - max_num_permuted_tokens: Target size of the output list - - Returns: - List of expanded IDs with length = max_num_permuted_tokens, - where permuted_idx_to_expanded_idx[permuted_idx] = expanded_idx - Padding tokens are marked with pad_val - Note: In kernel, use expanded_idx // top_k to get original token_idx - """ - permuted_idx_to_expanded_idx = [] - colmajor_expanded_idx = 0 - for i, curr_num_tokens in enumerate(num_tokens_per_expert): - curr_num_tiles = (curr_num_tokens + self.tile_size - - 1) // self.tile_size - for j in range(curr_num_tiles * self.tile_size): - if j < curr_num_tokens: - token_idx = colmajor_expanded_idx % num_tokens - topk_idx = colmajor_expanded_idx // num_tokens - expanded_idx = token_idx * self.top_k + topk_idx - permuted_idx_to_expanded_idx.append(expanded_idx) - colmajor_expanded_idx += 1 - else: - permuted_idx_to_expanded_idx.append( - self.pad_val) # Padding token - # Pad to max_num_permuted_tokens - while len(permuted_idx_to_expanded_idx) < max_num_permuted_tokens: - permuted_idx_to_expanded_idx.append(self.pad_val) - return permuted_idx_to_expanded_idx - def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]: """Pre-hook for gather-based SwiGLU fusion kernel. @@ -276,39 +261,31 @@ class GatherGroupedGemmInputsHelper(GroupedGemmInputsHelper): IDX_PERMUTED_IDX_TO_EXPANDED_IDX] is permuted_idx_to_expanded_idx max_num_permuted_tokens = permuted_idx_to_expanded_idx.size(0) - max_num_tiles = max_num_permuted_tokens // self.tile_size - num_tokens = self.infer_num_tokens(max_num_permuted_tokens) - num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens) - tile_idx_to_group_idx_list = self.generate_tile_idx_to_group_idx( - num_tokens_per_expert) - tile_idx_to_mn_limit_list = self.generate_tile_idx_to_mn_limit( - num_tokens_per_expert) - permuted_idx_to_expanded_idx_list = self.generate_permuted_idx_to_expanded_idx( - num_tokens, num_tokens_per_expert, max_num_permuted_tokens) - num_non_exiting_tiles_val = len(tile_idx_to_group_idx_list) - num_padding_tiles_val = max_num_tiles - num_non_exiting_tiles_val - assert num_non_exiting_tiles_val > 0 - assert num_padding_tiles_val >= 0 - assert len(tile_idx_to_mn_limit_list) == num_non_exiting_tiles_val - assert len(permuted_idx_to_expanded_idx_list) == max_num_permuted_tokens + num_tokens_per_expert = self.generate_num_tokens_per_expert( + num_tokens, approx_max_load=True) + token_selected_experts = self.generate_token_selected_experts( + num_tokens, num_tokens_per_expert) - tile_idx_to_group_idx = torch.tensor( - tile_idx_to_group_idx_list + [self.pad_val] * num_padding_tiles_val, - dtype=tile_idx_to_group_idx.dtype, - device=tile_idx_to_group_idx.device) - tile_idx_to_mn_limit = torch.tensor( - tile_idx_to_mn_limit_list + [self.pad_val] * num_padding_tiles_val, - dtype=tile_idx_to_mn_limit.dtype, - device=tile_idx_to_mn_limit.device) - permuted_idx_to_expanded_idx = torch.tensor( - permuted_idx_to_expanded_idx_list, - dtype=permuted_idx_to_expanded_idx.dtype, - device=permuted_idx_to_expanded_idx.device) - num_non_exiting_tiles = torch.tensor( - [num_non_exiting_tiles_val], - dtype=num_non_exiting_tiles.dtype, - device=num_non_exiting_tiles.device) + token_selected_experts = token_selected_experts.cuda() + token_final_scales = torch.ones_like(token_selected_experts, + dtype=torch.float32) + ( + tile_idx_to_group_idx, + tile_idx_to_mn_limit, + expanded_idx_to_permuted_idx, + 
permuted_idx_to_expanded_idx, + total_num_padded_tokens, + num_non_exiting_tiles, + ) = torch.ops.trtllm.moe_sort( + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + num_experts=self.num_experts, + top_k=self.top_k, + local_expert_offset=self.local_expert_offset, + local_num_experts=self.num_local_experts, + tile_tokens_dim=self.tile_size, + ) return (a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, global_sf) @@ -924,6 +901,7 @@ if IS_CUTLASS_DSL_AVAILABLE: 5, 0, helper.infer_shape_max_num_tiles)), inputs_pre_hook=helper.inputs_pre_hook, + use_cold_l2_cache=True, ) return self.__class__.tuning_config_cache[key] @@ -1222,6 +1200,7 @@ if IS_CUTLASS_DSL_AVAILABLE: 8, 0, helper.infer_shape_max_num_permuted_tokens), ConstraintSpec(10, 0, helper.infer_shape_num_tokens)), inputs_pre_hook=helper.inputs_pre_hook_finalize_fusion, + use_cold_l2_cache=True, ) return self.__class__.tuning_config_cache[key] @@ -1602,6 +1581,7 @@ if IS_CUTLASS_DSL_AVAILABLE: 5, 0, helper.infer_shape_max_num_tiles)), inputs_pre_hook=helper.inputs_pre_hook, + use_cold_l2_cache=True, ) return self.__class__.tuning_config_cache[key] @@ -1931,6 +1911,7 @@ if IS_CUTLASS_DSL_AVAILABLE: 6, 0, helper.infer_shape_max_num_tiles)), inputs_pre_hook=helper.inputs_pre_hook, + use_cold_l2_cache=True, ) return self.__class__.tuning_config_cache[key] diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py index 2c9ebfacce..7e434e39d3 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py @@ -175,7 +175,8 @@ class CuteDslFusedMoENvfp4InputsHelper(GroupedGemmInputsHelper): def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]: x, token_selected_experts, *others = inputs num_tokens = token_selected_experts.size(0) - num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens) + num_tokens_per_expert = self.generate_num_tokens_per_expert( + num_tokens, approx_max_load=True) new_token_selected_experts = [] for i, curr_num_tokens in enumerate(num_tokens_per_expert, @@ -258,6 +259,7 @@ class CuteDslFusedMoENvfp4Runner(TunableRunner): ConstraintSpec( 4, 0, helper.infer_shape_num_tokens)), inputs_pre_hook=helper.inputs_pre_hook, + use_cold_l2_cache=True, ) return self.__class__.tuning_config_cache[key] diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/runner.py b/tensorrt_llm/tools/layer_wise_benchmarks/runner.py index 4b6f9050e6..637a75ef35 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/runner.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/runner.py @@ -1,6 +1,7 @@ import contextlib import functools import itertools +import os import unittest.mock import weakref from enum import IntEnum @@ -11,6 +12,7 @@ import torch import tensorrt_llm._torch.model_config import tensorrt_llm.bindings from tensorrt_llm._torch.attention_backend.utils import get_attention_backend +from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import GroupedGemmInputsHelper from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_utils import PostInitCaller, skip_forward @@ -54,8 +56,15 @@ def round_up(a, b): return ceil_div(a, b) * b -def get_balanced_selection_no_cache( - num_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size +def 
get_balanced_selection_impl_default( + num_tokens: int, + top_k: int, + num_experts: int, + dtype: torch.dtype, + device: torch.device, + dp_size: int, + dp_rank: int, + ep_size: int, ): token_id = torch.arange(dp_rank * num_tokens * top_k, (dp_rank + 1) * num_tokens * top_k).view( num_tokens, top_k @@ -68,6 +77,32 @@ def get_balanced_selection_no_cache( return token_selected_experts.contiguous().to(dtype=dtype, device=device) +def get_balanced_selection_impl_random( + num_tokens: int, + top_k: int, + num_experts: int, + dtype: torch.dtype, + device: torch.device, + dp_size: int, + dp_rank: int, + ep_size: int, +): + helper = GroupedGemmInputsHelper(num_experts, top_k, num_experts, 0, 128) + num_tokens_per_expert = helper.generate_num_tokens_per_expert(num_tokens, approx_max_load=False) + assert sum(num_tokens_per_expert) == num_tokens * top_k + token_selected_experts = helper.generate_token_selected_experts( + num_tokens, num_tokens_per_expert + ) + return token_selected_experts.contiguous().to(dtype=dtype, device=device) + + +def get_balanced_selection_no_cache(*args, **kwargs): + if os.environ.get("TRTLLM_LAYERWISE_BENCHMARK_BALANCED_IMPL", "DEFAULT") == "RANDOM": + return get_balanced_selection_impl_random(*args, **kwargs) + else: + return get_balanced_selection_impl_default(*args, **kwargs) + + get_balanced_selection = functools.cache(get_balanced_selection_no_cache) diff --git a/tests/scripts/cute_dsl_kernels/README.md b/tests/scripts/cute_dsl_kernels/README.md new file mode 100644 index 0000000000..b8ed94ae04 --- /dev/null +++ b/tests/scripts/cute_dsl_kernels/README.md @@ -0,0 +1,11 @@ +# Launch Scripts for CuTe DSL Kernels + +## MoE Workload Generator + +```bash +# Generate workload using a balanced random method +# Per-rank token number 128, EP size 32 (a typical workload for large EP gen phase) +python moe_workload_generator.py --num_tokens 128 --ep_size 32 --tile_size 128 +# Per-rank token number 8192, EP size 4 (a typical workload for ctx phase) +python moe_workload_generator.py --num_tokens 8192 --ep_size 4 --tile_size 256 +``` diff --git a/tests/scripts/cute_dsl_kernels/moe_workload_generator.py b/tests/scripts/cute_dsl_kernels/moe_workload_generator.py new file mode 100644 index 0000000000..09ec52fafe --- /dev/null +++ b/tests/scripts/cute_dsl_kernels/moe_workload_generator.py @@ -0,0 +1,176 @@ +import json +import os +from typing import List, Optional + +import click +import safetensors.torch +import torch + +from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import GroupedGemmInputsHelper +from tensorrt_llm.tools.layer_wise_benchmarks.runner import ( + get_balanced_selection_impl_default, + get_balanced_selection_impl_random, +) + + +def get_balanced_selection_impl_default_legacy( + num_tokens: int, + top_k: int, + num_experts: int, + dtype: torch.dtype, + device: torch.device, + dp_size: int, + dp_rank: int, + ep_size: int, +): + world_size = ep_size + rank = dp_rank + # First, each sender selects target rank + target_rank_before_mod = torch.arange(num_tokens * world_size * top_k).view( + num_tokens, world_size, top_k + ) + target_rank_before_mod += top_k * torch.arange(num_tokens).view( + num_tokens, 1, 1 + ) # Shift `top_k` ranks for the next token on each rank, to balance network traffic + target_rank = target_rank_before_mod % world_size + # Second, each receiver selects target expert + target_expert = torch.empty_like(target_rank) + for reciever_rank in range(world_size): + mask = target_rank == reciever_rank + experts_per_rank = num_experts // world_size + 
local_expert = torch.arange(num_tokens * top_k) % experts_per_rank + target_expert[mask] = (reciever_rank * experts_per_rank) + local_expert + token_selected_experts = target_expert[:, rank].sort(dim=-1).values + return token_selected_experts.contiguous().to(dtype=dtype, device=device) + + +def gen_moe_workload( + num_tokens: int, + top_k: int, + num_experts: int, + ep_size: int, + num_tokens_per_expert: Optional[List[int]], + tile_size: int, + method: str = "balanced_random", +): + if num_tokens_per_expert is not None: + num_local_experts = len(num_tokens_per_expert) + assert num_local_experts * ep_size == num_experts + helper = GroupedGemmInputsHelper(num_experts, top_k, num_local_experts, 0, tile_size) + token_selected_experts = helper.generate_token_selected_experts( + num_tokens * ep_size, num_tokens_per_expert + ) + token_selected_experts = token_selected_experts.cuda() + else: + if method == "balanced_random": + get_balanced_selection_impl = get_balanced_selection_impl_random + elif method == "balanced_default": + get_balanced_selection_impl = get_balanced_selection_impl_default + elif method == "balanced_default_legacy": + get_balanced_selection_impl = get_balanced_selection_impl_default_legacy + else: + raise ValueError(f"Invalid method: {method}.") + + token_selected_experts = [ + get_balanced_selection_impl( + num_tokens=num_tokens, + top_k=top_k, + num_experts=num_experts, + dtype=torch.int32, + device="cuda", + dp_size=ep_size, + dp_rank=i, + ep_size=ep_size, + ) + for i in range(ep_size) + ] + token_selected_experts = torch.cat(token_selected_experts, dim=0) + + assert token_selected_experts.size() == (num_tokens * ep_size, top_k) + token_final_scales = torch.ones_like(token_selected_experts, dtype=torch.float32) + return torch.ops.trtllm.moe_sort( + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + num_experts=num_experts, + top_k=top_k, + local_expert_offset=0, + local_num_experts=num_experts // ep_size, + tile_tokens_dim=tile_size, + ) + + +@click.command("moe_workload_generator") +@click.option("--num_tokens", type=int, default=128) +@click.option("--top_k", type=int, default=8) +@click.option("--num_experts", type=int, default=256) +@click.option("--ep_size", type=int, default=32) +@click.option("--num_tokens_per_expert", type=str, default=None) +@click.option("--tile_size", type=click.Choice([128, 256]), default=128) +@click.option( + "--method", + type=click.Choice(["balanced_random", "balanced_default", "balanced_default_legacy"]), + default="balanced_random", +) +@click.option("--seed", type=int, default=515) +@click.option("--output_path", type=str, default="./moe_workload") +def main( + num_tokens: int, + top_k: int, + num_experts: int, + ep_size: int, + num_tokens_per_expert: str, + tile_size: int, + method: str, + seed: int, + output_path: str, +): + torch.manual_seed(seed) + + if num_tokens_per_expert is not None: + num_tokens_per_expert = [int(x) for x in num_tokens_per_expert.split(",")] + + ( + tile_idx_to_group_idx, + tile_idx_to_mn_limit, + expanded_idx_to_permuted_idx, + permuted_idx_to_expanded_idx, + total_num_padded_tokens, + num_non_exiting_tiles, + ) = gen_moe_workload( + num_tokens=num_tokens, + top_k=top_k, + num_experts=num_experts, + ep_size=ep_size, + num_tokens_per_expert=num_tokens_per_expert, + tile_size=tile_size, + method=method, + ) + + if not os.path.isdir(output_path): + os.makedirs(output_path) + + metadata = { + "num_tokens": num_tokens, + "top_k": top_k, + "num_experts": num_experts, + "ep_size": 
ep_size, + "tile_size": tile_size, + "method": method, + "seed": seed, + } + with open(f"{output_path}/metadata.json", "w") as f: + json.dump(metadata, f) + + workload = { + "tile_idx_to_group_idx": tile_idx_to_group_idx, + "tile_idx_to_mn_limit": tile_idx_to_mn_limit, + "expanded_idx_to_permuted_idx": expanded_idx_to_permuted_idx, + "permuted_idx_to_expanded_idx": permuted_idx_to_expanded_idx, + "total_num_padded_tokens": total_num_padded_tokens, + "num_non_exiting_tiles": num_non_exiting_tiles, + } + safetensors.torch.save_file(workload, f"{output_path}/workload.safetensors") + + +if __name__ == "__main__": + main() diff --git a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py index c5445c21f4..b176c03685 100644 --- a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py +++ b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py @@ -2,10 +2,7 @@ import pytest import torch from utils.util import check_accuracy -from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import ( - GatherGroupedGemmInputsHelper, - GroupedGemmInputsHelper, -) +from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import GroupedGemmInputsHelper from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import cute_dsl_nvfp4_grouped_gemm_ref from tensorrt_llm._torch.modules.fused_moe.quantization import interleave_linear_and_gate from tensorrt_llm._torch.utils import swizzle_sf, unswizzle_sf @@ -740,55 +737,29 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell( # Generate routing information routing_logits = torch.randn(num_tokens, num_experts, device="cuda") - _, token_selected_experts = routing_logits.topk(top_k, dim=-1) + token_final_scales, token_selected_experts = routing_logits.topk(top_k, dim=-1) token_selected_experts = token_selected_experts.to(torch.int32) - num_tokens_per_expert = torch.bincount(token_selected_experts.flatten(), minlength=num_experts) - num_tokens_per_expert = num_tokens_per_expert[:num_local_experts] - # Ensure at least one valid token - if num_tokens_per_expert.sum().item() == 0: - num_tokens_per_expert[0] = 1 - num_tiles_per_expert = (num_tokens_per_expert + tile_size - 1) // tile_size - num_tokens_per_expert = num_tokens_per_expert.cpu() - num_tiles_per_expert = num_tiles_per_expert.cpu() - num_valid_tiles = num_tiles_per_expert.sum().item() - num_valid_permuted_tokens = num_valid_tiles * tile_size + token_final_scales = token_final_scales.softmax(dim=-1).to(torch.float32) - # Create helper - helper = GatherGroupedGemmInputsHelper(num_experts, top_k, num_local_experts, 0, tile_size) - max_num_tiles = helper.get_max_num_tiles(num_tokens) - max_num_permuted_tokens = helper.get_max_num_permuted_tokens(num_tokens) - assert 0 <= num_valid_tiles <= max_num_tiles - assert 0 <= num_valid_permuted_tokens <= max_num_permuted_tokens - - # Generate tile metadata - num_non_exiting_tiles = torch.tensor([num_valid_tiles], dtype=torch.int32, device="cuda") - tile_idx_to_group_idx = torch.empty(max_num_tiles, dtype=torch.int32) - tile_idx_to_mn_limit = torch.empty(max_num_tiles, dtype=torch.int32) - tile_idx_to_group_idx.fill_(int(-2e9)) - tile_idx_to_mn_limit.fill_(int(-2e9)) - - tile_idx_to_group_idx_list = helper.generate_tile_idx_to_group_idx( - num_tokens_per_expert.tolist() + ( + tile_idx_to_group_idx, + tile_idx_to_mn_limit, + expanded_idx_to_permuted_idx, + permuted_idx_to_expanded_idx, + total_num_padded_tokens, + num_non_exiting_tiles, + ) = torch.ops.trtllm.moe_sort( + 
token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + num_experts=num_experts, + top_k=top_k, + local_expert_offset=0, + local_num_experts=num_local_experts, + tile_tokens_dim=tile_size, ) - tile_idx_to_mn_limit_list = helper.generate_tile_idx_to_mn_limit(num_tokens_per_expert.tolist()) - for idx, (group_idx, mn_limit) in enumerate( - zip(tile_idx_to_group_idx_list, tile_idx_to_mn_limit_list) - ): - tile_idx_to_group_idx[idx] = group_idx - tile_idx_to_mn_limit[idx] = mn_limit - - tile_idx_to_group_idx = tile_idx_to_group_idx.cuda() - tile_idx_to_mn_limit = tile_idx_to_mn_limit.cuda() - - # Generate permuted_idx_to_expanded_idx for gather operation - permuted_idx_to_expanded_idx_list = helper.generate_permuted_idx_to_expanded_idx( - num_tokens, num_tokens_per_expert.tolist(), max_num_permuted_tokens - ) - permuted_idx_to_expanded_idx = torch.tensor( - permuted_idx_to_expanded_idx_list, dtype=torch.int32, device="cuda" - ) - assert permuted_idx_to_expanded_idx.size(0) == max_num_permuted_tokens + max_num_permuted_tokens = permuted_idx_to_expanded_idx.size(0) + num_valid_permuted_tokens = total_num_padded_tokens.item() # Create input tensors (original size, not permuted) a = torch.randint(-5, 5, (num_tokens, hidden_size), dtype=torch.int32, device="cuda").to( @@ -826,18 +797,22 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell( ) # Compute reference: manually gather, compute GEMM, apply SwiGLU, then quantize - a_gathered = torch.empty( - max_num_permuted_tokens, hidden_size // 2, dtype=a.dtype, device=a.device - ) + permuted_idx_to_expanded_idx_list = permuted_idx_to_expanded_idx.cpu().tolist() + tile_idx_to_mn_limit_list = tile_idx_to_mn_limit.cpu().tolist() + + a_gathered = torch.empty(max_num_permuted_tokens, hidden_size // 2, dtype=a.dtype) a_sf_gathered = torch.empty( - max_num_permuted_tokens, hidden_size // sf_vec_size, dtype=a_sf.dtype, device=a_sf.device + max_num_permuted_tokens, hidden_size // sf_vec_size, dtype=a_sf.dtype ) for i in range(num_valid_permuted_tokens): - expanded_idx = permuted_idx_to_expanded_idx[i].item() - if expanded_idx != helper.pad_val: - token_id = expanded_idx // top_k - a_gathered[i] = a[token_id] - a_sf_gathered[i] = a_sf_unswizzled[token_id] + if i >= tile_idx_to_mn_limit_list[i // tile_size]: + continue + expanded_idx = permuted_idx_to_expanded_idx_list[i] + token_id = expanded_idx // top_k + a_gathered[i] = a[token_id] + a_sf_gathered[i] = a_sf_unswizzled[token_id] + a_gathered = a_gathered.to(a.device) + a_sf_gathered = a_sf_gathered.to(a.device) # Swizzle a_sf_gathered for reference GEMM a_sf_gathered_swizzled = swizzle_sf( @@ -886,8 +861,9 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell( # Create mask for valid tokens valid_token_mask = torch.zeros(num_valid_permuted_tokens, dtype=torch.bool, device="cuda") for i in range(num_valid_permuted_tokens): - if permuted_idx_to_expanded_idx[i].item() != helper.pad_val: - valid_token_mask[i] = True + if i >= tile_idx_to_mn_limit_list[i // tile_size]: + continue + valid_token_mask[i] = True num_valid_tokens = valid_token_mask.sum().item() if num_valid_tokens > 0: @@ -905,9 +881,10 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell( c_sf_valid = [] c_sf_ref_valid = [] for i in range(num_valid_permuted_tokens): - if permuted_idx_to_expanded_idx[i].item() != helper.pad_val: - c_sf_valid.append(c_sf_unswizzled[i]) - c_sf_ref_valid.append(c_sf_ref_unswizzled[i]) + if i >= tile_idx_to_mn_limit_list[i // tile_size]: + continue + c_sf_valid.append(c_sf_unswizzled[i]) 
+ c_sf_ref_valid.append(c_sf_ref_unswizzled[i]) c_sf_valid = torch.cat(c_sf_valid) c_sf_ref_valid = torch.cat(c_sf_ref_valid)
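The `approx_max_load=True` path added above estimates the busiest EP rank's load using the balls-into-bins bound referenced in the comment: the mean load plus `c * sqrt(mean * ln(ep_size))` with `c = 1.0`, then splits that load evenly across the local experts. Below is a minimal, self-contained sketch of that estimate; the function name and the parameter values in the example are illustrative only and are not part of the patch.

```python
import math
from typing import List


def approx_num_tokens_per_expert(num_tokens: int,
                                 top_k: int,
                                 num_experts: int,
                                 num_local_experts: int,
                                 approx_max_load: bool = True,
                                 c: float = 1.0) -> List[int]:
    """Standalone sketch of GroupedGemmInputsHelper.generate_num_tokens_per_expert.

    Estimates the number of tokens landing on the current EP rank (the mean
    load, or the mean plus the balls-into-bins term c * sqrt(mean * ln(ep_size))
    when approx_max_load is set) and splits that load as evenly as possible
    across the local experts.
    """
    ep_size = num_experts // num_local_experts
    average_num_tokens_per_rank = num_tokens * top_k / ep_size
    if approx_max_load:
        extra = c * math.sqrt(average_num_tokens_per_rank * math.log(ep_size))
        num_tokens_on_curr_rank = math.ceil(average_num_tokens_per_rank + extra)
    else:
        num_tokens_on_curr_rank = math.ceil(average_num_tokens_per_rank)
    # The rank can never receive more than every expanded token.
    num_tokens_on_curr_rank = min(num_tokens * top_k, num_tokens_on_curr_rank)
    base, remainder = divmod(num_tokens_on_curr_rank, num_local_experts)
    return [base + 1] * remainder + [base] * (num_local_experts - remainder)


# Illustrative numbers: 4096 routed tokens, top_k=8, 256 experts, 8 local experts
# (EP size 32): mean load 1024, approx max load ceil(1024 + sqrt(1024 * ln 32)) = 1084,
# split into [136]*4 + [135]*4.
print(approx_num_tokens_per_expert(4096, 8, 256, 8))
```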