[TRTLLM-9819][perf] Reuse alltoall workspace for CuteDSL MoE output (#9840)

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Enwei Zhu 2025-12-19 02:36:38 +08:00 committed by GitHub
parent 0b279f4ad4
commit 6fe89ea00f
7 changed files with 98 additions and 54 deletions

View File

@ -205,24 +205,26 @@ std::tuple<torch::Tensor, torch::optional<torch::Tensor>> moe_permute(torch::Ten
// Unpermute
torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor const& expanded_idx_to_permuted_idx,
torch::Tensor const& topk_scales)
void moe_unpermute_inplace(torch::Tensor const& permuted_input, torch::Tensor const& output,
torch::Tensor const& expanded_idx_to_permuted_idx, torch::Tensor const& topk_scales)
{
TORCH_CHECK(permuted_input.dim() == 2, "permuted_input must be 2D.");
int64_t const max_num_permuted_tokens = permuted_input.size(0);
int64_t const hidden_size = permuted_input.size(1);
TORCH_CHECK(output.dim() == 2, "output must be 2D.");
int64_t const num_tokens = output.size(0);
TORCH_CHECK(output.size(1) == hidden_size, "output.size(1) must be hidden_size.");
TORCH_CHECK(expanded_idx_to_permuted_idx.dim() == 2, "expanded_idx_to_permuted_idx must be 2D.");
int64_t const num_tokens = expanded_idx_to_permuted_idx.size(0);
TORCH_CHECK(
expanded_idx_to_permuted_idx.size(0) == num_tokens, "expanded_idx_to_permuted_idx.size(0) must be num_tokens.");
int64_t const top_k = expanded_idx_to_permuted_idx.size(1);
TORCH_CHECK(topk_scales.dim() == 2, "topk_scales must be 2D.");
TORCH_CHECK(topk_scales.size(0) == num_tokens, "topk_scales.size(0) must be num_tokens.");
TORCH_CHECK(topk_scales.size(1) == top_k, "topk_scales.size(1) must be top_k.");
TORCH_CHECK(max_num_permuted_tokens >= num_tokens * top_k,
"max_num_permuted_tokens must be greater than or equal to num_tokens * top_k.");
auto output
= torch::empty({num_tokens, hidden_size}, torch::dtype(permuted_input.scalar_type()).device(torch::kCUDA));
auto const& stream = at::cuda::getCurrentCUDAStream(permuted_input.get_device());
#define DISPATCH_MOE_UNPERMUTE(InputType, TopKScaleType) \
@ -253,7 +255,19 @@ torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor c
}
#undef DISPATCH_MOE_UNPERMUTE
}
torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor const& expanded_idx_to_permuted_idx,
torch::Tensor const& topk_scales)
{
TORCH_CHECK(permuted_input.dim() == 2, "permuted_input must be 2D.");
int64_t const hidden_size = permuted_input.size(1);
TORCH_CHECK(expanded_idx_to_permuted_idx.dim() == 2, "expanded_idx_to_permuted_idx must be 2D.");
int64_t const num_tokens = expanded_idx_to_permuted_idx.size(0);
auto output
= torch::empty({num_tokens, hidden_size}, torch::dtype(permuted_input.scalar_type()).device(torch::kCUDA));
moe_unpermute_inplace(permuted_input, output, expanded_idx_to_permuted_idx, topk_scales);
return output;
}
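The split above leaves moe_unpermute as a thin wrapper that allocates the output tensor and forwards to moe_unpermute_inplace. A minimal Python-side sketch of the two calling conventions, using the op schemas registered below (tensor shapes and the index/scale dtypes are illustrative assumptions, not taken from the kernels):

import torch

num_tokens, top_k, hidden_size = 8, 2, 4096
permuted = torch.randn(num_tokens * top_k, hidden_size, dtype=torch.bfloat16, device="cuda")
expanded_idx_to_permuted_idx = torch.arange(
    num_tokens * top_k, dtype=torch.int32, device="cuda").view(num_tokens, top_k)
topk_scales = torch.rand(num_tokens, top_k, dtype=torch.bfloat16, device="cuda")

# Allocating variant: returns a fresh [num_tokens, hidden_size] tensor.
out = torch.ops.trtllm.moe_unpermute(
    permuted, expanded_idx_to_permuted_idx, topk_scales)

# In-place variant: writes into a caller-provided buffer, e.g. a view of the
# alltoall workspace, so the finalize step needs no extra allocation or copy.
workspace_view = torch.empty(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
torch.ops.trtllm.moe_unpermute_inplace(
    permuted, workspace_view, expanded_idx_to_permuted_idx, topk_scales)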
@ -489,6 +503,9 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
m.def(
"moe_permute(Tensor input, Tensor? input_sf, Tensor tile_idx_to_mn_limit, Tensor permuted_idx_to_expanded_idx, "
"Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k) -> (Tensor, Tensor?)");
m.def(
"moe_unpermute_inplace(Tensor permuted_input, Tensor(a!) output, Tensor expanded_idx_to_permuted_idx, Tensor "
"topk_scales) -> ()");
m.def("moe_unpermute(Tensor permuted_input, Tensor expanded_idx_to_permuted_idx, Tensor topk_scales) -> Tensor");
m.def(
"moe_output_memset_inplace(Tensor(a!) input, Tensor tile_idx_to_mn_limit, Tensor expanded_idx_to_permuted_idx, "
@ -510,6 +527,7 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
m.impl("moe_topk_sort", &tensorrt_llm::torch_ext::moe_topk_sort);
m.impl("moe_sort", &tensorrt_llm::torch_ext::moe_sort);
m.impl("moe_permute", &tensorrt_llm::torch_ext::moe_permute);
m.impl("moe_unpermute_inplace", &tensorrt_llm::torch_ext::moe_unpermute_inplace);
m.impl("moe_unpermute", &tensorrt_llm::torch_ext::moe_unpermute);
m.impl("moe_output_memset_inplace", &tensorrt_llm::torch_ext::moe_output_memset_inplace);
m.impl("moe_swiglu", &tensorrt_llm::torch_ext::moe_swiglu);

View File

@ -77,6 +77,9 @@ def inplace_info():
torch.ops.trtllm.logits_bitmask.default: {
1: "logits"
},
torch.ops.trtllm.moe_unpermute_inplace.default: {
2: "output"
},
torch.ops.trtllm.moe_output_memset_inplace.default: {
1: "input"
},

View File

@ -991,23 +991,17 @@ class ConfigurableMoE(MoE):
if not isinstance(self.comm, NVLinkOneSided):
return None
if not self.backend.supports_moe_output_in_alltoall_workspace():
# Ensure payload_in_workspace is False if backend doesn't support it
self.comm.payload_in_workspace = False
return None
# Determine workspace dtype and whether backend supports workspace output
workspace_dtype = output_dtype
backend_supports_workspace = False
if isinstance(self.backend, TRTLLMGenFusedMoE):
# TRTLLMGen specific configuration
self.comm.invalid_token_expert_id = -1
workspace_dtype = torch.bfloat16
backend_supports_workspace = self.backend.has_w4a8_mxfp4_mxfp8
elif isinstance(self.backend, CutlassFusedMoE):
# Cutlass always supports workspace output with NVLinkOneSided
backend_supports_workspace = True
if not backend_supports_workspace:
# Ensure payload_in_workspace is False if backend doesn't support it
self.comm.payload_in_workspace = False
return None
# Calculate runtime max tokens per rank
assert all_rank_num_tokens is not None, (
@ -1022,7 +1016,6 @@ class ConfigurableMoE(MoE):
# Dynamically enable payload_in_workspace for this forward pass
self.comm.payload_in_workspace = True
return moe_output
def _get_backend_kwargs(
@ -1096,13 +1089,18 @@ class ConfigurableMoE(MoE):
# Get moe_output for NVLinkOneSided backend
kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
all_rank_num_tokens, output_dtype
all_rank_num_tokens=all_rank_num_tokens, output_dtype=output_dtype
)
# CuteDSL-specific parameters
elif self.backend.__class__ == CuteDslFusedMoE:
kwargs["enable_alltoall"] = self.enable_alltoall
# Get moe_output for NVLinkOneSided backend
kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
all_rank_num_tokens=all_rank_num_tokens, output_dtype=output_dtype
)
# DeepGemm-specific parameters
elif self.backend.__class__ == DeepGemmFusedMoE:
if workspace is not None:
@ -1123,7 +1121,7 @@ class ConfigurableMoE(MoE):
# Get moe_output for NVLinkOneSided backend
kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
all_rank_num_tokens, output_dtype
all_rank_num_tokens=all_rank_num_tokens, output_dtype=output_dtype
)
return kwargs
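A self-contained toy sketch of the data flow this file wires up, under the simplifying assumption that the communicator owns the alltoall workspace and the backend writes straight into any buffer it is handed (the classes below are stand-ins, not the real NVLinkOneSided or fused-MoE backends):

import torch

class ToyComm:
    """Stand-in communicator: owns a workspace and hands out output views."""
    def __init__(self, max_tokens, hidden_size, dtype):
        self.workspace = torch.zeros(max_tokens, hidden_size, dtype=dtype)
        self.payload_in_workspace = False

    def moe_output_view(self, num_tokens):
        return self.workspace[:num_tokens]

class ToyBackend:
    """Stand-in backend: writes its result into moe_output when one is supplied."""
    def supports_moe_output_in_alltoall_workspace(self):
        return True

    def run_moe(self, x, moe_output=None):
        if moe_output is None:
            moe_output = torch.empty_like(x)
        moe_output.copy_(x * 2)  # pretend MoE computation
        return moe_output

comm, backend = ToyComm(16, 8, torch.float32), ToyBackend()
x = torch.randn(4, 8)

# Gate on backend support, then thread the workspace view through as moe_output.
moe_output = (comm.moe_output_view(x.size(0))
              if backend.supports_moe_output_in_alltoall_workspace() else None)
comm.payload_in_workspace = moe_output is not None
out = backend.run_moe(x, moe_output=moe_output)
assert out.data_ptr() == comm.workspace.data_ptr()  # output lives in the comm workspace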

View File

@ -210,6 +210,9 @@ class CuteDslFusedMoE(CutlassFusedMoE):
return NVFP4CuteDslFusedMoEMethod()
return super()._get_quant_method()
def supports_moe_output_in_alltoall_workspace(self):
return self.has_nvfp4
def quantize_input(self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
post_quant_comm: bool = True):
@ -258,6 +261,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
token_selected_experts: torch.Tensor,
token_final_scales: Optional[torch.Tensor],
x_sf: Optional[torch.Tensor] = None,
moe_output: Optional[torch.Tensor] = None,
enable_alltoall: bool = False,
) -> torch.Tensor:
assert self.has_nvfp4
@ -274,6 +278,16 @@ class CuteDslFusedMoE(CutlassFusedMoE):
tile_tokens_dim=tile_size,
)
if moe_output is None:
moe_output = torch.empty(
(token_final_scales.size(0), self.hidden_size),
dtype=output_dtype,
device=x.device)
else:
assert moe_output.size() == (token_final_scales.size(0),
self.hidden_size)
assert moe_output.dtype == output_dtype
x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell(
input=x.view(torch.float4_e2m1fn_x2),
weight=self.w3_w1_weight.view(torch.float4_e2m1fn_x2),
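The allocate-or-validate block added above is what lets run_moe_nvfp4 accept a workspace-backed buffer without changing the default path. A minimal CPU-side sketch of the same pattern (function and tensor names here are illustrative, not TRT-LLM APIs):

import torch

def resolve_moe_output(moe_output, num_tokens, hidden_size, dtype, device):
    """Use the caller-provided buffer if given, otherwise allocate a fresh one."""
    if moe_output is None:
        return torch.empty((num_tokens, hidden_size), dtype=dtype, device=device)
    assert moe_output.size() == (num_tokens, hidden_size)
    assert moe_output.dtype == dtype
    return moe_output

# Reusing a slice of a larger (hypothetical) communication workspace avoids a copy:
workspace = torch.zeros(16, 8, dtype=torch.bfloat16)
out = resolve_moe_output(workspace[:4], 4, 8, torch.bfloat16, workspace.device)
assert out.data_ptr() == workspace.data_ptr()  # no new allocation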
@ -291,12 +305,10 @@ class CuteDslFusedMoE(CutlassFusedMoE):
local_expert_offset=self.slot_start,
tile_size=tile_size,
)
if self.use_fused_finalize:
output = torch.empty((token_final_scales.size(0), self.hidden_size),
dtype=output_dtype,
device=x.device)
torch.ops.trtllm.moe_output_memset_inplace(
input=output,
input=moe_output,
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
@ -313,7 +325,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
weight_scale=self.quant_scales.fc2_weight_block.view(
torch.uint8),
alpha=self.quant_scales.fc2_global,
output=output,
output=moe_output,
tile_idx_to_group_idx=tile_idx_to_expert_idx,
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
@ -326,7 +338,6 @@ class CuteDslFusedMoE(CutlassFusedMoE):
tile_size=tile_size,
output_dtype=output_dtype,
)
x = output
else:
x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell(
input=x.view(torch.float4_e2m1fn_x2),
@ -344,12 +355,13 @@ class CuteDslFusedMoE(CutlassFusedMoE):
tile_size=tile_size,
output_dtype=output_dtype,
)
x = torch.ops.trtllm.moe_unpermute(
torch.ops.trtllm.moe_unpermute_inplace(
permuted_input=x,
output=moe_output,
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
topk_scales=token_final_scales,
)
return x
return moe_output
def run_moe_fp8_block_scales(
self,
@ -364,12 +376,12 @@ class CuteDslFusedMoE(CutlassFusedMoE):
weight_dtype = self.w3_w1_weight.dtype
(
permuted_row_to_unpermuted_row_tensor,
permuted_token_selected_experts_tensor,
permuted_data_tensor,
expert_first_token_offset_tensor,
permuted_token_final_scales_tensor,
unpermuted_row_to_permuted_row_tensor,
permuted_row_to_unpermuted_row,
permuted_token_selected_experts,
x,
expert_first_token_offset,
permuted_token_final_scales,
unpermuted_row_to_permuted_row,
) = torch.ops.trtllm.moe_permute_op(
x,
token_selected_experts,
@ -388,35 +400,34 @@ class CuteDslFusedMoE(CutlassFusedMoE):
min_latency_mode=False,
use_fp8_block_scaling=True,
)
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
permuted_data_tensor)
h1 = cute_dsl_fp8_group_blockwise_gemm_ref(
a=act_input_fp8,
x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x)
x = cute_dsl_fp8_group_blockwise_gemm_ref(
a=x,
b=self.w3_w1_weight.view(weight_dtype),
a_sf=act_input_sf,
a_sf=x_sf,
b_sf=self.quant_scales[0],
offset_array=expert_first_token_offset_tensor,
offset_array=expert_first_token_offset,
)
h2 = swiglu_fused_moe(h1)
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(h2)
h3 = cute_dsl_fp8_group_blockwise_gemm_ref(
a=act_input_fp8,
x = swiglu_fused_moe(x)
x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x)
x = cute_dsl_fp8_group_blockwise_gemm_ref(
a=x,
b=self.w2_weight.view(weight_dtype),
a_sf=act_input_sf,
a_sf=x_sf,
b_sf=self.quant_scales[1],
offset_array=expert_first_token_offset_tensor,
offset_array=expert_first_token_offset,
)
h4 = torch.ops.trtllm.moe_finalize_scale_op(
h3,
x = torch.ops.trtllm.moe_finalize_scale_op(
x,
None, # biases
token_final_scales,
unpermuted_row_to_permuted_row_tensor,
permuted_row_to_unpermuted_row_tensor,
unpermuted_row_to_permuted_row,
permuted_row_to_unpermuted_row,
token_selected_experts,
expert_first_token_offset_tensor,
expert_first_token_offset,
enable_alltoall,
x.shape[0], # num_rows
x.shape[1], # (possibly padded) hidden_size
token_final_scales.size(0), # num_rows
self.hidden_size, # (possibly padded) hidden_size
self.unpadded_hidden_size, # original hidden size
self.routing_method.top_k,
self.expert_size_per_partition, # num_experts_per_node
@ -425,7 +436,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
self.ep_size,
self.ep_rank,
)
return h4
return x
def run_moe(
self,
@ -433,6 +444,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
token_selected_experts: torch.Tensor,
token_final_scales: Optional[torch.Tensor],
x_sf: Optional[torch.Tensor] = None,
moe_output: Optional[torch.Tensor] = None,
enable_alltoall: bool = False,
) -> torch.Tensor:
"""
@ -448,6 +460,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
this represents expert slots [num_tokens, top_k] instead.
token_final_scales: Final scaling factors for each token
x_sf: Input scale factors (optional, for certain quantization schemes)
moe_output: Pre-allocated MoE output buffer (optional; used with the NVLinkOneSided alltoall communicator).
enable_alltoall: Whether alltoall communication is enabled.
Returns:
@ -459,6 +472,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
x_sf=x_sf,
moe_output=moe_output,
enable_alltoall=enable_alltoall)
elif self.has_deepseek_fp8_block_scales:
return self.run_moe_fp8_block_scales(

View File

@ -389,6 +389,9 @@ class CutlassFusedMoE(MoE):
self._weights_created = True
self._check_configs()
def supports_moe_output_in_alltoall_workspace(self):
return True
def run_moe(
self,
x: torch.Tensor,

View File

@ -354,6 +354,9 @@ class TRTLLMGenFusedMoE(MoE):
return x, x_sf
def supports_moe_output_in_alltoall_workspace(self):
return self.has_w4a8_mxfp4_mxfp8
def run_moe(
self,
x: torch.Tensor,

View File

@ -723,6 +723,11 @@ class MoE(nn.Module):
def expand_intermediate_size_per_partition(self):
return self.intermediate_size_per_partition * self.intermediate_size_expand_ratio
def supports_moe_output_in_alltoall_workspace(self):
""" Supports moe_output in alltoall workspace
"""
return False
def reducescatter_or_allreduce(
self,
inputs,
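Taken together, the overrides in this commit form a simple opt-in pattern: the base class defaults to False and each backend reports its own capability. A condensed sketch, with class bodies reduced to just this method and the attribute values shown as illustrative defaults (in TRT-LLM they come from the quantization config):

class MoE:
    def supports_moe_output_in_alltoall_workspace(self):
        return False  # conservative default: backends must opt in explicitly

class CutlassFusedMoE(MoE):
    def supports_moe_output_in_alltoall_workspace(self):
        return True  # always supported with the NVLinkOneSided communicator

class TRTLLMGenFusedMoE(MoE):
    has_w4a8_mxfp4_mxfp8 = False  # illustrative placeholder

    def supports_moe_output_in_alltoall_workspace(self):
        return self.has_w4a8_mxfp4_mxfp8

class CuteDslFusedMoE(CutlassFusedMoE):
    has_nvfp4 = True  # illustrative placeholder

    def supports_moe_output_in_alltoall_workspace(self):
        return self.has_nvfp4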