[TRTLLM-10147][perf] Balanced random MoE workload generator for CuteDSL kernel UT, autotuner and layerwise benchmark (#10279)

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>

parent fd7fd8c39d
commit 72ef732bcf
@@ -1,4 +1,5 @@
import itertools
import math
from typing import List, Optional, Tuple

import torch

@@ -31,13 +32,19 @@ class GroupedGemmInputsHelper:
    IDX_A = 0
    IDX_SHAPE_INFER = IDX_A  # Default: use a tensor for shape inference

    def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
                 local_expert_offset: int, tile_size: int):
    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 num_local_experts: int,
                 local_expert_offset: int,
                 tile_size: int,
                 seed: int = 515):
        self.num_experts = num_experts
        self.top_k = top_k
        self.num_local_experts = num_local_experts
        self.local_expert_offset = local_expert_offset
        self.tile_size = tile_size
        self.seed = seed
        # Padding values should never be accessed.
        # Intentionally use a large padding value to expose issues early.
        self.pad_val = int(2e9)
@@ -82,118 +89,134 @@ class GroupedGemmInputsHelper:
            self, input_shapes: List[torch.Size]) -> int:
        return self.infer_shape_max_num_tiles(input_shapes) * self.tile_size

    def generate_num_tokens_per_expert(self, num_tokens: int) -> List[int]:
        average_num_tokens_per_expert = num_tokens * self.top_k / self.num_experts
        balance = 0
        num_tokens_per_expert = []
        for i in range(self.num_local_experts):
            balance += average_num_tokens_per_expert
            if balance <= 1e-3:
                continue
            curr_num_tokens = int(balance) + 1
            num_tokens_per_expert.append(curr_num_tokens)
            balance -= curr_num_tokens
    def generate_num_tokens_per_expert(self,
                                       num_tokens: int,
                                       approx_max_load: bool = False
                                       ) -> List[int]:
        ep_size = self.num_experts // self.num_local_experts
        average_num_tokens_per_rank = num_tokens * self.top_k / ep_size

        if approx_max_load:
            # https://en.wikipedia.org/wiki/Balls_into_bins_problem
            # The constant c can be measured empirically, we choose 1.0 for simplicity.
            c = 1.0
            extra_num_tokens_on_curr_rank = c * math.sqrt(
                average_num_tokens_per_rank * math.log(ep_size))
            num_tokens_on_curr_rank = math.ceil(average_num_tokens_per_rank +
                                                extra_num_tokens_on_curr_rank)
        else:
            num_tokens_on_curr_rank = math.ceil(average_num_tokens_per_rank)

        num_tokens_on_curr_rank = min(num_tokens * self.top_k,
                                      num_tokens_on_curr_rank)

        base, remainder = divmod(num_tokens_on_curr_rank,
                                 self.num_local_experts)
        num_tokens_per_expert = [base + 1] * remainder + [base] * (
            self.num_local_experts - remainder)
        assert len(num_tokens_per_expert) == self.num_local_experts
        assert sum(num_tokens_per_expert) == num_tokens_on_curr_rank
        return num_tokens_per_expert
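As a quick numeric illustration of the `approx_max_load` branch above (an editorial sketch with hypothetical sizes, not part of the diff): with `c = 1.0`, the heaviest EP rank is estimated to receive roughly `avg + sqrt(avg * ln(ep_size))` routed tokens.

```python
import math

# Hypothetical workload: 4096 tokens in total, top_k = 8, ep_size = 32, c = 1.0 as above.
num_tokens, top_k, ep_size, c = 4096, 8, 32, 1.0
avg = num_tokens * top_k / ep_size               # 1024.0 routed tokens per rank on average
extra = c * math.sqrt(avg * math.log(ep_size))   # ~59.6 extra tokens on the busiest rank
print(math.ceil(avg + extra))                    # -> 1084, the approximate max per-rank load
```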

    def generate_tile_idx_to_group_idx(
            self, num_tokens_per_expert: List[int]) -> List[int]:
        tile_idx_to_group_idx = []
        for i, curr_num_tokens in enumerate(num_tokens_per_expert):
            curr_num_tiles = (curr_num_tokens + self.tile_size -
                              1) // self.tile_size
            tile_idx_to_group_idx.extend([i] * curr_num_tiles)
        return tile_idx_to_group_idx

    def generate_tile_idx_to_mn_limit(
            self, num_tokens_per_expert: List[int]) -> List[int]:
        tile_idx_to_mn_limit = []
        for i, curr_num_tokens in enumerate(num_tokens_per_expert):
            curr_num_tiles = (curr_num_tokens + self.tile_size -
                              1) // self.tile_size
            prev_mn_limit = len(tile_idx_to_mn_limit) * self.tile_size
            for j in range(curr_num_tiles):
                tile_idx_to_mn_limit.append(prev_mn_limit + min(
                    (j + 1) * self.tile_size, curr_num_tokens))
        return tile_idx_to_mn_limit

    def generate_permuted_idx_to_expanded_idx(
    def generate_token_selected_experts(
            self, num_tokens: int,
            num_tokens_per_expert: List[int]) -> List[int]:
        permuted_idx_to_expanded_idx = []
        colmajor_expanded_idx = 0
        for i, curr_num_tokens in enumerate(num_tokens_per_expert):
            curr_num_tiles = (curr_num_tokens + self.tile_size -
                              1) // self.tile_size
            for j in range(curr_num_tiles * self.tile_size):
                if j < curr_num_tokens:
                    token_idx = colmajor_expanded_idx % num_tokens
                    topk_idx = colmajor_expanded_idx // num_tokens
                    expanded_idx = token_idx * self.top_k + topk_idx
                    permuted_idx_to_expanded_idx.append(expanded_idx)
                    colmajor_expanded_idx += 1
                else:
                    permuted_idx_to_expanded_idx.append(self.pad_val)
        return permuted_idx_to_expanded_idx
            num_tokens_per_expert: List[int]) -> torch.Tensor:
        """Balanced random based on rejection sampling.
        """
        token_selected_experts = -torch.ones(
            num_tokens, self.top_k, dtype=torch.int32)
        num_selected_experts = torch.zeros(num_tokens, dtype=torch.int32)

        with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
            torch.manual_seed(self.seed)
            selection_orders = [
                torch.randperm(num_tokens)
                for _ in range(self.num_local_experts)
            ]

        for j, num_tokens_j in enumerate(num_tokens_per_expert):
            selection_order_j = selection_orders[j].tolist()
            prioritized = torch.nonzero(num_selected_experts <= (
                self.top_k - (self.num_experts - j))).squeeze(-1).tolist()
            if len(prioritized) > 0:
                selection_order_j = prioritized + [
                    i for i in selection_order_j if i not in prioritized
                ]
            for i in selection_order_j:
                if num_selected_experts[i] < self.top_k:
                    token_selected_experts[
                        i,
                        num_selected_experts[i]] = j + self.local_expert_offset
                    num_selected_experts[i] += 1
                    num_tokens_j -= 1
                    if num_tokens_j <= 0:
                        break

        assert ((token_selected_experts
                 >= 0).sum(dim=-1) == num_selected_experts).all().item()
        if self.num_local_experts == self.num_experts:
            assert (num_selected_experts == self.top_k).all().item()
        else:
            assert (num_selected_experts <= self.top_k).all().item()
        return token_selected_experts

    def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
        a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others = inputs
        num_tokens = self.infer_num_tokens(a.size(0))
        num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens)
        tile_idx_to_group_idx_list = self.generate_tile_idx_to_group_idx(
            num_tokens_per_expert)
        num_non_exiting_tiles_val = len(tile_idx_to_group_idx_list)
        num_padding_tiles_val = tile_idx_to_group_idx.size(
            0) - num_non_exiting_tiles_val
        assert num_non_exiting_tiles_val > 0
        assert num_padding_tiles_val >= 0
        num_tokens_per_expert = self.generate_num_tokens_per_expert(
            num_tokens, approx_max_load=True)
        token_selected_experts = self.generate_token_selected_experts(
            num_tokens, num_tokens_per_expert)

        tile_idx_to_group_idx = torch.tensor(
            tile_idx_to_group_idx_list + [self.pad_val] * num_padding_tiles_val,
            dtype=tile_idx_to_group_idx.dtype,
            device=tile_idx_to_group_idx.device)
        num_non_exiting_tiles = torch.tensor(
            [num_non_exiting_tiles_val],
            dtype=num_non_exiting_tiles.dtype,
            device=num_non_exiting_tiles.device)
        token_selected_experts = token_selected_experts.cuda()
        token_final_scales = torch.ones_like(token_selected_experts,
                                             dtype=torch.float32)
        (
            tile_idx_to_group_idx,
            tile_idx_to_mn_limit,
            expanded_idx_to_permuted_idx,
            permuted_idx_to_expanded_idx,
            total_num_padded_tokens,
            num_non_exiting_tiles,
        ) = torch.ops.trtllm.moe_sort(
            token_selected_experts=token_selected_experts,
            token_final_scales=token_final_scales,
            num_experts=self.num_experts,
            top_k=self.top_k,
            local_expert_offset=self.local_expert_offset,
            local_num_experts=self.num_local_experts,
            tile_tokens_dim=self.tile_size,
        )
        return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others

    def inputs_pre_hook_finalize_fusion(
            self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
        a, b, a_sf, b_sf, alpha, output, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs
        num_tokens = self.infer_num_tokens(a.size(0))
        num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens)
        tile_idx_to_group_idx_list = self.generate_tile_idx_to_group_idx(
            num_tokens_per_expert)
        tile_idx_to_mn_limit_list = self.generate_tile_idx_to_mn_limit(
            num_tokens_per_expert)
        permuted_idx_to_expanded_idx_list = self.generate_permuted_idx_to_expanded_idx(
        num_tokens_per_expert = self.generate_num_tokens_per_expert(
            num_tokens, approx_max_load=True)
        token_selected_experts = self.generate_token_selected_experts(
            num_tokens, num_tokens_per_expert)
        num_non_exiting_tiles_val = len(tile_idx_to_group_idx_list)
        num_padding_tiles_val = tile_idx_to_group_idx.size(
            0) - num_non_exiting_tiles_val
        assert num_non_exiting_tiles_val > 0
        assert num_padding_tiles_val >= 0
        assert len(tile_idx_to_mn_limit_list) == num_non_exiting_tiles_val
        assert len(permuted_idx_to_expanded_idx_list
                   ) == num_non_exiting_tiles_val * self.tile_size

        tile_idx_to_group_idx = torch.tensor(
            tile_idx_to_group_idx_list + [self.pad_val] * num_padding_tiles_val,
            dtype=tile_idx_to_group_idx.dtype,
            device=tile_idx_to_group_idx.device)
        tile_idx_to_mn_limit = torch.tensor(
            tile_idx_to_mn_limit_list + [self.pad_val] * num_padding_tiles_val,
            dtype=tile_idx_to_mn_limit.dtype,
            device=tile_idx_to_mn_limit.device)
        permuted_idx_to_expanded_idx = torch.tensor(
            permuted_idx_to_expanded_idx_list + [self.pad_val] *
            (num_padding_tiles_val * self.tile_size),
            dtype=permuted_idx_to_expanded_idx.dtype,
            device=permuted_idx_to_expanded_idx.device)
        num_non_exiting_tiles = torch.tensor(
            [num_non_exiting_tiles_val],
            dtype=num_non_exiting_tiles.dtype,
            device=num_non_exiting_tiles.device)
        token_selected_experts = token_selected_experts.cuda()
        token_final_scales = torch.ones_like(token_selected_experts,
                                             dtype=torch.float32)
        (
            tile_idx_to_group_idx,
            tile_idx_to_mn_limit,
            expanded_idx_to_permuted_idx,
            permuted_idx_to_expanded_idx,
            total_num_padded_tokens,
            num_non_exiting_tiles,
        ) = torch.ops.trtllm.moe_sort(
            token_selected_experts=token_selected_experts,
            token_final_scales=token_final_scales,
            num_experts=self.num_experts,
            top_k=self.top_k,
            local_expert_offset=self.local_expert_offset,
            local_num_experts=self.num_local_experts,
            tile_tokens_dim=self.tile_size,
        )
        return a, b, a_sf, b_sf, alpha, output, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales

@@ -221,44 +244,6 @@ class GatherGroupedGemmInputsHelper(GroupedGemmInputsHelper):
    IDX_PERMUTED_IDX_TO_EXPANDED_IDX = 7
    IDX_SHAPE_INFER = IDX_PERMUTED_IDX_TO_EXPANDED_IDX

    def generate_permuted_idx_to_expanded_idx(
            self, num_tokens: int, num_tokens_per_expert: List[int],
            max_num_permuted_tokens: int) -> List[int]:
        """Generate permuted_idx_to_expanded_idx for gather operation.

        Maps permuted index to expanded index (token_idx * top_k + topk_idx).

        Args:
            num_tokens: Total number of input tokens
            num_tokens_per_expert: List of token counts per expert
            max_num_permuted_tokens: Target size of the output list

        Returns:
            List of expanded IDs with length = max_num_permuted_tokens,
            where permuted_idx_to_expanded_idx[permuted_idx] = expanded_idx
            Padding tokens are marked with pad_val
            Note: In kernel, use expanded_idx // top_k to get original token_idx
        """
        permuted_idx_to_expanded_idx = []
        colmajor_expanded_idx = 0
        for i, curr_num_tokens in enumerate(num_tokens_per_expert):
            curr_num_tiles = (curr_num_tokens + self.tile_size -
                              1) // self.tile_size
            for j in range(curr_num_tiles * self.tile_size):
                if j < curr_num_tokens:
                    token_idx = colmajor_expanded_idx % num_tokens
                    topk_idx = colmajor_expanded_idx // num_tokens
                    expanded_idx = token_idx * self.top_k + topk_idx
                    permuted_idx_to_expanded_idx.append(expanded_idx)
                    colmajor_expanded_idx += 1
                else:
                    permuted_idx_to_expanded_idx.append(
                        self.pad_val)  # Padding token
        # Pad to max_num_permuted_tokens
        while len(permuted_idx_to_expanded_idx) < max_num_permuted_tokens:
            permuted_idx_to_expanded_idx.append(self.pad_val)
        return permuted_idx_to_expanded_idx

    def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
        """Pre-hook for gather-based SwiGLU fusion kernel.

@@ -276,39 +261,31 @@ class GatherGroupedGemmInputsHelper(GroupedGemmInputsHelper):
            IDX_PERMUTED_IDX_TO_EXPANDED_IDX] is permuted_idx_to_expanded_idx

        max_num_permuted_tokens = permuted_idx_to_expanded_idx.size(0)
        max_num_tiles = max_num_permuted_tokens // self.tile_size

        num_tokens = self.infer_num_tokens(max_num_permuted_tokens)
        num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens)
        tile_idx_to_group_idx_list = self.generate_tile_idx_to_group_idx(
            num_tokens_per_expert)
        tile_idx_to_mn_limit_list = self.generate_tile_idx_to_mn_limit(
            num_tokens_per_expert)
        permuted_idx_to_expanded_idx_list = self.generate_permuted_idx_to_expanded_idx(
            num_tokens, num_tokens_per_expert, max_num_permuted_tokens)
        num_non_exiting_tiles_val = len(tile_idx_to_group_idx_list)
        num_padding_tiles_val = max_num_tiles - num_non_exiting_tiles_val
        assert num_non_exiting_tiles_val > 0
        assert num_padding_tiles_val >= 0
        assert len(tile_idx_to_mn_limit_list) == num_non_exiting_tiles_val
        assert len(permuted_idx_to_expanded_idx_list) == max_num_permuted_tokens
        num_tokens_per_expert = self.generate_num_tokens_per_expert(
            num_tokens, approx_max_load=True)
        token_selected_experts = self.generate_token_selected_experts(
            num_tokens, num_tokens_per_expert)

        tile_idx_to_group_idx = torch.tensor(
            tile_idx_to_group_idx_list + [self.pad_val] * num_padding_tiles_val,
            dtype=tile_idx_to_group_idx.dtype,
            device=tile_idx_to_group_idx.device)
        tile_idx_to_mn_limit = torch.tensor(
            tile_idx_to_mn_limit_list + [self.pad_val] * num_padding_tiles_val,
            dtype=tile_idx_to_mn_limit.dtype,
            device=tile_idx_to_mn_limit.device)
        permuted_idx_to_expanded_idx = torch.tensor(
            permuted_idx_to_expanded_idx_list,
            dtype=permuted_idx_to_expanded_idx.dtype,
            device=permuted_idx_to_expanded_idx.device)
        num_non_exiting_tiles = torch.tensor(
            [num_non_exiting_tiles_val],
            dtype=num_non_exiting_tiles.dtype,
            device=num_non_exiting_tiles.device)
        token_selected_experts = token_selected_experts.cuda()
        token_final_scales = torch.ones_like(token_selected_experts,
                                             dtype=torch.float32)
        (
            tile_idx_to_group_idx,
            tile_idx_to_mn_limit,
            expanded_idx_to_permuted_idx,
            permuted_idx_to_expanded_idx,
            total_num_padded_tokens,
            num_non_exiting_tiles,
        ) = torch.ops.trtllm.moe_sort(
            token_selected_experts=token_selected_experts,
            token_final_scales=token_final_scales,
            num_experts=self.num_experts,
            top_k=self.top_k,
            local_expert_offset=self.local_expert_offset,
            local_num_experts=self.num_local_experts,
            tile_tokens_dim=self.tile_size,
        )
        return (a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx,
                tile_idx_to_mn_limit, permuted_idx_to_expanded_idx,
                num_non_exiting_tiles, global_sf)
@@ -924,6 +901,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
                        5, 0,
                        helper.infer_shape_max_num_tiles)),
                    inputs_pre_hook=helper.inputs_pre_hook,
                    use_cold_l2_cache=True,
                )
                return self.__class__.tuning_config_cache[key]

@@ -1222,6 +1200,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
                        8, 0, helper.infer_shape_max_num_permuted_tokens),
                    ConstraintSpec(10, 0, helper.infer_shape_num_tokens)),
                    inputs_pre_hook=helper.inputs_pre_hook_finalize_fusion,
                    use_cold_l2_cache=True,
                )
                return self.__class__.tuning_config_cache[key]

@@ -1602,6 +1581,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
                        5, 0,
                        helper.infer_shape_max_num_tiles)),
                    inputs_pre_hook=helper.inputs_pre_hook,
                    use_cold_l2_cache=True,
                )
                return self.__class__.tuning_config_cache[key]

@@ -1931,6 +1911,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
                        6, 0,
                        helper.infer_shape_max_num_tiles)),
                    inputs_pre_hook=helper.inputs_pre_hook,
                    use_cold_l2_cache=True,
                )
                return self.__class__.tuning_config_cache[key]

@@ -175,7 +175,8 @@ class CuteDslFusedMoENvfp4InputsHelper(GroupedGemmInputsHelper):
    def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
        x, token_selected_experts, *others = inputs
        num_tokens = token_selected_experts.size(0)
        num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens)
        num_tokens_per_expert = self.generate_num_tokens_per_expert(
            num_tokens, approx_max_load=True)

        new_token_selected_experts = []
        for i, curr_num_tokens in enumerate(num_tokens_per_expert,

@@ -258,6 +259,7 @@ class CuteDslFusedMoENvfp4Runner(TunableRunner):
                    ConstraintSpec(
                        4, 0, helper.infer_shape_num_tokens)),
                    inputs_pre_hook=helper.inputs_pre_hook,
                    use_cold_l2_cache=True,
                )
                return self.__class__.tuning_config_cache[key]

@@ -1,6 +1,7 @@
import contextlib
import functools
import itertools
import os
import unittest.mock
import weakref
from enum import IntEnum

@@ -11,6 +12,7 @@ import torch
import tensorrt_llm._torch.model_config
import tensorrt_llm.bindings
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import GroupedGemmInputsHelper
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import PostInitCaller, skip_forward

@@ -54,8 +56,15 @@ def round_up(a, b):
    return ceil_div(a, b) * b


def get_balanced_selection_no_cache(
    num_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size
def get_balanced_selection_impl_default(
    num_tokens: int,
    top_k: int,
    num_experts: int,
    dtype: torch.dtype,
    device: torch.device,
    dp_size: int,
    dp_rank: int,
    ep_size: int,
):
    token_id = torch.arange(dp_rank * num_tokens * top_k, (dp_rank + 1) * num_tokens * top_k).view(
        num_tokens, top_k

@@ -68,6 +77,32 @@ def get_balanced_selection_no_cache(
    return token_selected_experts.contiguous().to(dtype=dtype, device=device)


def get_balanced_selection_impl_random(
    num_tokens: int,
    top_k: int,
    num_experts: int,
    dtype: torch.dtype,
    device: torch.device,
    dp_size: int,
    dp_rank: int,
    ep_size: int,
):
    helper = GroupedGemmInputsHelper(num_experts, top_k, num_experts, 0, 128)
    num_tokens_per_expert = helper.generate_num_tokens_per_expert(num_tokens, approx_max_load=False)
    assert sum(num_tokens_per_expert) == num_tokens * top_k
    token_selected_experts = helper.generate_token_selected_experts(
        num_tokens, num_tokens_per_expert
    )
    return token_selected_experts.contiguous().to(dtype=dtype, device=device)


def get_balanced_selection_no_cache(*args, **kwargs):
    if os.environ.get("TRTLLM_LAYERWISE_BENCHMARK_BALANCED_IMPL", "DEFAULT") == "RANDOM":
        return get_balanced_selection_impl_random(*args, **kwargs)
    else:
        return get_balanced_selection_impl_default(*args, **kwargs)


get_balanced_selection = functools.cache(get_balanced_selection_no_cache)
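As a usage note (a sketch assuming the module import path shown in the generator below and an available CUDA device): the environment variable checked above switches the layer-wise benchmark to the balanced random routing implementation.

```python
import os

import torch

from tensorrt_llm.tools.layer_wise_benchmarks.runner import get_balanced_selection

# The env var is read on each uncached call; set it before the first call for a
# given argument tuple so the RANDOM implementation is what ends up cached.
os.environ["TRTLLM_LAYERWISE_BENCHMARK_BALANCED_IMPL"] = "RANDOM"

# Hypothetical sizes: 128 tokens, top_k=8, 256 experts, dp_size=32, dp_rank=0, ep_size=32.
experts = get_balanced_selection(128, 8, 256, torch.int32, "cuda", 32, 0, 32)
print(experts.shape)  # torch.Size([128, 8])
```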

tests/scripts/cute_dsl_kernels/README.md (new file, 11 lines)
@@ -0,0 +1,11 @@
# Launch Scripts for CuTe DSL Kernels

## MoE Workload Generator

```bash
# Generate a workload using the balanced random method.
# 128 tokens per rank, EP size 32 (a typical workload for the large-EP generation phase)
python moe_workload_generator.py --num_tokens 128 --ep_size 32 --tile_size 128
# 8192 tokens per rank, EP size 4 (a typical workload for the context phase)
python moe_workload_generator.py --num_tokens 8192 --ep_size 4 --tile_size 256
```
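The generator writes `metadata.json` and `workload.safetensors` under `--output_path` (default `./moe_workload`); a minimal sketch for loading them back:

```python
import json

import safetensors.torch

# Paths assume the default --output_path of ./moe_workload.
with open("./moe_workload/metadata.json") as f:
    metadata = json.load(f)
workload = safetensors.torch.load_file("./moe_workload/workload.safetensors")

print(metadata["num_tokens"], metadata["ep_size"])
print(workload["tile_idx_to_group_idx"].shape, workload["num_non_exiting_tiles"])
```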

tests/scripts/cute_dsl_kernels/moe_workload_generator.py (new file, 176 lines)
@@ -0,0 +1,176 @@
import json
import os
from typing import List, Optional

import click
import safetensors.torch
import torch

from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import GroupedGemmInputsHelper
from tensorrt_llm.tools.layer_wise_benchmarks.runner import (
    get_balanced_selection_impl_default,
    get_balanced_selection_impl_random,
)


def get_balanced_selection_impl_default_legacy(
    num_tokens: int,
    top_k: int,
    num_experts: int,
    dtype: torch.dtype,
    device: torch.device,
    dp_size: int,
    dp_rank: int,
    ep_size: int,
):
    world_size = ep_size
    rank = dp_rank
    # First, each sender selects target rank
    target_rank_before_mod = torch.arange(num_tokens * world_size * top_k).view(
        num_tokens, world_size, top_k
    )
    target_rank_before_mod += top_k * torch.arange(num_tokens).view(
        num_tokens, 1, 1
    )  # Shift `top_k` ranks for the next token on each rank, to balance network traffic
    target_rank = target_rank_before_mod % world_size
    # Second, each receiver selects target expert
    target_expert = torch.empty_like(target_rank)
    for reciever_rank in range(world_size):
        mask = target_rank == reciever_rank
        experts_per_rank = num_experts // world_size
        local_expert = torch.arange(num_tokens * top_k) % experts_per_rank
        target_expert[mask] = (reciever_rank * experts_per_rank) + local_expert
    token_selected_experts = target_expert[:, rank].sort(dim=-1).values
    return token_selected_experts.contiguous().to(dtype=dtype, device=device)


def gen_moe_workload(
    num_tokens: int,
    top_k: int,
    num_experts: int,
    ep_size: int,
    num_tokens_per_expert: Optional[List[int]],
    tile_size: int,
    method: str = "balanced_random",
):
    if num_tokens_per_expert is not None:
        num_local_experts = len(num_tokens_per_expert)
        assert num_local_experts * ep_size == num_experts
        helper = GroupedGemmInputsHelper(num_experts, top_k, num_local_experts, 0, tile_size)
        token_selected_experts = helper.generate_token_selected_experts(
            num_tokens * ep_size, num_tokens_per_expert
        )
        token_selected_experts = token_selected_experts.cuda()
    else:
        if method == "balanced_random":
            get_balanced_selection_impl = get_balanced_selection_impl_random
        elif method == "balanced_default":
            get_balanced_selection_impl = get_balanced_selection_impl_default
        elif method == "balanced_default_legacy":
            get_balanced_selection_impl = get_balanced_selection_impl_default_legacy
        else:
            raise ValueError(f"Invalid method: {method}.")

        token_selected_experts = [
            get_balanced_selection_impl(
                num_tokens=num_tokens,
                top_k=top_k,
                num_experts=num_experts,
                dtype=torch.int32,
                device="cuda",
                dp_size=ep_size,
                dp_rank=i,
                ep_size=ep_size,
            )
            for i in range(ep_size)
        ]
        token_selected_experts = torch.cat(token_selected_experts, dim=0)

    assert token_selected_experts.size() == (num_tokens * ep_size, top_k)
    token_final_scales = torch.ones_like(token_selected_experts, dtype=torch.float32)
    return torch.ops.trtllm.moe_sort(
        token_selected_experts=token_selected_experts,
        token_final_scales=token_final_scales,
        num_experts=num_experts,
        top_k=top_k,
        local_expert_offset=0,
        local_num_experts=num_experts // ep_size,
        tile_tokens_dim=tile_size,
    )


@click.command("moe_workload_generator")
@click.option("--num_tokens", type=int, default=128)
@click.option("--top_k", type=int, default=8)
@click.option("--num_experts", type=int, default=256)
@click.option("--ep_size", type=int, default=32)
@click.option("--num_tokens_per_expert", type=str, default=None)
@click.option("--tile_size", type=click.Choice([128, 256]), default=128)
@click.option(
    "--method",
    type=click.Choice(["balanced_random", "balanced_default", "balanced_default_legacy"]),
    default="balanced_random",
)
@click.option("--seed", type=int, default=515)
@click.option("--output_path", type=str, default="./moe_workload")
def main(
    num_tokens: int,
    top_k: int,
    num_experts: int,
    ep_size: int,
    num_tokens_per_expert: str,
    tile_size: int,
    method: str,
    seed: int,
    output_path: str,
):
    torch.manual_seed(seed)

    if num_tokens_per_expert is not None:
        num_tokens_per_expert = [int(x) for x in num_tokens_per_expert.split(",")]

    (
        tile_idx_to_group_idx,
        tile_idx_to_mn_limit,
        expanded_idx_to_permuted_idx,
        permuted_idx_to_expanded_idx,
        total_num_padded_tokens,
        num_non_exiting_tiles,
    ) = gen_moe_workload(
        num_tokens=num_tokens,
        top_k=top_k,
        num_experts=num_experts,
        ep_size=ep_size,
        num_tokens_per_expert=num_tokens_per_expert,
        tile_size=tile_size,
        method=method,
    )

    if not os.path.isdir(output_path):
        os.makedirs(output_path)

    metadata = {
        "num_tokens": num_tokens,
        "top_k": top_k,
        "num_experts": num_experts,
        "ep_size": ep_size,
        "tile_size": tile_size,
        "method": method,
        "seed": seed,
    }
    with open(f"{output_path}/metadata.json", "w") as f:
        json.dump(metadata, f)

    workload = {
        "tile_idx_to_group_idx": tile_idx_to_group_idx,
        "tile_idx_to_mn_limit": tile_idx_to_mn_limit,
        "expanded_idx_to_permuted_idx": expanded_idx_to_permuted_idx,
        "permuted_idx_to_expanded_idx": permuted_idx_to_expanded_idx,
        "total_num_padded_tokens": total_num_padded_tokens,
        "num_non_exiting_tiles": num_non_exiting_tiles,
    }
    safetensors.torch.save_file(workload, f"{output_path}/workload.safetensors")


if __name__ == "__main__":
    main()

@@ -2,10 +2,7 @@ import pytest
import torch
from utils.util import check_accuracy

from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import (
    GatherGroupedGemmInputsHelper,
    GroupedGemmInputsHelper,
)
from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import GroupedGemmInputsHelper
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import cute_dsl_nvfp4_grouped_gemm_ref
from tensorrt_llm._torch.modules.fused_moe.quantization import interleave_linear_and_gate
from tensorrt_llm._torch.utils import swizzle_sf, unswizzle_sf

@@ -740,55 +737,29 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell(

    # Generate routing information
    routing_logits = torch.randn(num_tokens, num_experts, device="cuda")
    _, token_selected_experts = routing_logits.topk(top_k, dim=-1)
    token_final_scales, token_selected_experts = routing_logits.topk(top_k, dim=-1)
    token_selected_experts = token_selected_experts.to(torch.int32)
    num_tokens_per_expert = torch.bincount(token_selected_experts.flatten(), minlength=num_experts)
    num_tokens_per_expert = num_tokens_per_expert[:num_local_experts]
    # Ensure at least one valid token
    if num_tokens_per_expert.sum().item() == 0:
        num_tokens_per_expert[0] = 1
    num_tiles_per_expert = (num_tokens_per_expert + tile_size - 1) // tile_size
    num_tokens_per_expert = num_tokens_per_expert.cpu()
    num_tiles_per_expert = num_tiles_per_expert.cpu()
    num_valid_tiles = num_tiles_per_expert.sum().item()
    num_valid_permuted_tokens = num_valid_tiles * tile_size
    token_final_scales = token_final_scales.softmax(dim=-1).to(torch.float32)

    # Create helper
    helper = GatherGroupedGemmInputsHelper(num_experts, top_k, num_local_experts, 0, tile_size)
    max_num_tiles = helper.get_max_num_tiles(num_tokens)
    max_num_permuted_tokens = helper.get_max_num_permuted_tokens(num_tokens)
    assert 0 <= num_valid_tiles <= max_num_tiles
    assert 0 <= num_valid_permuted_tokens <= max_num_permuted_tokens

    # Generate tile metadata
    num_non_exiting_tiles = torch.tensor([num_valid_tiles], dtype=torch.int32, device="cuda")
    tile_idx_to_group_idx = torch.empty(max_num_tiles, dtype=torch.int32)
    tile_idx_to_mn_limit = torch.empty(max_num_tiles, dtype=torch.int32)
    tile_idx_to_group_idx.fill_(int(-2e9))
    tile_idx_to_mn_limit.fill_(int(-2e9))

    tile_idx_to_group_idx_list = helper.generate_tile_idx_to_group_idx(
        num_tokens_per_expert.tolist()
    (
        tile_idx_to_group_idx,
        tile_idx_to_mn_limit,
        expanded_idx_to_permuted_idx,
        permuted_idx_to_expanded_idx,
        total_num_padded_tokens,
        num_non_exiting_tiles,
    ) = torch.ops.trtllm.moe_sort(
        token_selected_experts=token_selected_experts,
        token_final_scales=token_final_scales,
        num_experts=num_experts,
        top_k=top_k,
        local_expert_offset=0,
        local_num_experts=num_local_experts,
        tile_tokens_dim=tile_size,
    )
    tile_idx_to_mn_limit_list = helper.generate_tile_idx_to_mn_limit(num_tokens_per_expert.tolist())

    for idx, (group_idx, mn_limit) in enumerate(
        zip(tile_idx_to_group_idx_list, tile_idx_to_mn_limit_list)
    ):
        tile_idx_to_group_idx[idx] = group_idx
        tile_idx_to_mn_limit[idx] = mn_limit

    tile_idx_to_group_idx = tile_idx_to_group_idx.cuda()
    tile_idx_to_mn_limit = tile_idx_to_mn_limit.cuda()

    # Generate permuted_idx_to_expanded_idx for gather operation
    permuted_idx_to_expanded_idx_list = helper.generate_permuted_idx_to_expanded_idx(
        num_tokens, num_tokens_per_expert.tolist(), max_num_permuted_tokens
    )
    permuted_idx_to_expanded_idx = torch.tensor(
        permuted_idx_to_expanded_idx_list, dtype=torch.int32, device="cuda"
    )
    assert permuted_idx_to_expanded_idx.size(0) == max_num_permuted_tokens
    max_num_permuted_tokens = permuted_idx_to_expanded_idx.size(0)
    num_valid_permuted_tokens = total_num_padded_tokens.item()

    # Create input tensors (original size, not permuted)
    a = torch.randint(-5, 5, (num_tokens, hidden_size), dtype=torch.int32, device="cuda").to(

@@ -826,18 +797,22 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell(
    )

    # Compute reference: manually gather, compute GEMM, apply SwiGLU, then quantize
    a_gathered = torch.empty(
        max_num_permuted_tokens, hidden_size // 2, dtype=a.dtype, device=a.device
    )
    permuted_idx_to_expanded_idx_list = permuted_idx_to_expanded_idx.cpu().tolist()
    tile_idx_to_mn_limit_list = tile_idx_to_mn_limit.cpu().tolist()

    a_gathered = torch.empty(max_num_permuted_tokens, hidden_size // 2, dtype=a.dtype)
    a_sf_gathered = torch.empty(
        max_num_permuted_tokens, hidden_size // sf_vec_size, dtype=a_sf.dtype, device=a_sf.device
        max_num_permuted_tokens, hidden_size // sf_vec_size, dtype=a_sf.dtype
    )
    for i in range(num_valid_permuted_tokens):
        expanded_idx = permuted_idx_to_expanded_idx[i].item()
        if expanded_idx != helper.pad_val:
            token_id = expanded_idx // top_k
            a_gathered[i] = a[token_id]
            a_sf_gathered[i] = a_sf_unswizzled[token_id]
        if i >= tile_idx_to_mn_limit_list[i // tile_size]:
            continue
        expanded_idx = permuted_idx_to_expanded_idx_list[i]
        token_id = expanded_idx // top_k
        a_gathered[i] = a[token_id]
        a_sf_gathered[i] = a_sf_unswizzled[token_id]
    a_gathered = a_gathered.to(a.device)
    a_sf_gathered = a_sf_gathered.to(a.device)

    # Swizzle a_sf_gathered for reference GEMM
    a_sf_gathered_swizzled = swizzle_sf(

@@ -886,8 +861,9 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell(
    # Create mask for valid tokens
    valid_token_mask = torch.zeros(num_valid_permuted_tokens, dtype=torch.bool, device="cuda")
    for i in range(num_valid_permuted_tokens):
        if permuted_idx_to_expanded_idx[i].item() != helper.pad_val:
            valid_token_mask[i] = True
        if i >= tile_idx_to_mn_limit_list[i // tile_size]:
            continue
        valid_token_mask[i] = True

    num_valid_tokens = valid_token_mask.sum().item()
    if num_valid_tokens > 0:

@@ -905,9 +881,10 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell(
        c_sf_valid = []
        c_sf_ref_valid = []
        for i in range(num_valid_permuted_tokens):
            if permuted_idx_to_expanded_idx[i].item() != helper.pad_val:
                c_sf_valid.append(c_sf_unswizzled[i])
                c_sf_ref_valid.append(c_sf_ref_unswizzled[i])
            if i >= tile_idx_to_mn_limit_list[i // tile_size]:
                continue
            c_sf_valid.append(c_sf_unswizzled[i])
            c_sf_ref_valid.append(c_sf_ref_unswizzled[i])

        c_sf_valid = torch.cat(c_sf_valid)
        c_sf_ref_valid = torch.cat(c_sf_ref_valid)