From 6fe89ea00f2fab394b44b1b91ceed6266a84cff2 Mon Sep 17 00:00:00 2001
From: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Date: Fri, 19 Dec 2025 02:36:38 +0800
Subject: [PATCH] [TRTLLM-9819][perf] Reuse alltoall workspace for CuteDSL MoE
 output (#9840)

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
---
 cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp   | 30 +++++--
 tensorrt_llm/_torch/compilation/utils.py      |  3 +
 .../modules/fused_moe/configurable_moe.py     | 26 +++---
 .../modules/fused_moe/fused_moe_cute_dsl.py   | 82 +++++++++++--------
 .../modules/fused_moe/fused_moe_cutlass.py    |  3 +
 .../modules/fused_moe/fused_moe_trtllm_gen.py |  3 +
 .../_torch/modules/fused_moe/interface.py     |  5 ++
 7 files changed, 98 insertions(+), 54 deletions(-)

diff --git a/cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp b/cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp
index 770c1459f9..5f17e2372b 100644
--- a/cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp
+++ b/cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp
@@ -205,24 +205,26 @@ std::tuple> moe_permute(torch::Ten
 // Unpermute
 
-torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor const& expanded_idx_to_permuted_idx,
-    torch::Tensor const& topk_scales)
+void moe_unpermute_inplace(torch::Tensor const& permuted_input, torch::Tensor const& output,
+    torch::Tensor const& expanded_idx_to_permuted_idx, torch::Tensor const& topk_scales)
 {
     TORCH_CHECK(permuted_input.dim() == 2, "permuted_input must be 2D.");
     int64_t const max_num_permuted_tokens = permuted_input.size(0);
     int64_t const hidden_size = permuted_input.size(1);
 
+    TORCH_CHECK(output.dim() == 2, "output must be 2D.");
+    int64_t const num_tokens = output.size(0);
+    TORCH_CHECK(output.size(1) == hidden_size, "output.size(1) must be hidden_size.");
+
     TORCH_CHECK(expanded_idx_to_permuted_idx.dim() == 2, "expanded_idx_to_permuted_idx must be 2D.");
-    int64_t const num_tokens = expanded_idx_to_permuted_idx.size(0);
+    TORCH_CHECK(
+        expanded_idx_to_permuted_idx.size(0) == num_tokens, "expanded_idx_to_permuted_idx.size(0) must be num_tokens.");
     int64_t const top_k = expanded_idx_to_permuted_idx.size(1);
 
     TORCH_CHECK(topk_scales.dim() == 2, "topk_scales must be 2D.");
     TORCH_CHECK(topk_scales.size(0) == num_tokens, "topk_scales.size(0) must be num_tokens.");
     TORCH_CHECK(topk_scales.size(1) == top_k, "topk_scales.size(1) must be top_k.");
 
-    TORCH_CHECK(max_num_permuted_tokens >= num_tokens * top_k,
-        "max_num_permuted_tokens must be greater than or equal to num_tokens * top_k.");
-
-    auto output
-        = torch::empty({num_tokens, hidden_size}, torch::dtype(permuted_input.scalar_type()).device(torch::kCUDA));
     auto const& stream = at::cuda::getCurrentCUDAStream(permuted_input.get_device());
 
 #define DISPATCH_MOE_UNPERMUTE(InputType, TopKScaleType)                                                               \
@@ -253,7 +255,19 @@ torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor c
     }
 
 #undef DISPATCH_MOE_UNPERMUTE
+}
 
+torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor const& expanded_idx_to_permuted_idx,
+    torch::Tensor const& topk_scales)
+{
+    TORCH_CHECK(permuted_input.dim() == 2, "permuted_input must be 2D.");
+    int64_t const hidden_size = permuted_input.size(1);
+    TORCH_CHECK(expanded_idx_to_permuted_idx.dim() == 2, "expanded_idx_to_permuted_idx must be 2D.");
+    int64_t const num_tokens = expanded_idx_to_permuted_idx.size(0);
+
+    auto output
+        = torch::empty({num_tokens, hidden_size}, torch::dtype(permuted_input.scalar_type()).device(torch::kCUDA));
+    moe_unpermute_inplace(permuted_input, output, expanded_idx_to_permuted_idx, topk_scales);
     return output;
 }
 
@@ -489,6 +503,9 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
     m.def(
         "moe_permute(Tensor input, Tensor? input_sf, Tensor tile_idx_to_mn_limit, Tensor permuted_idx_to_expanded_idx, "
        "Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k) -> (Tensor, Tensor?)");
+    m.def(
+        "moe_unpermute_inplace(Tensor permuted_input, Tensor(a!) output, Tensor expanded_idx_to_permuted_idx, Tensor "
+        "topk_scales) -> ()");
     m.def("moe_unpermute(Tensor permuted_input, Tensor expanded_idx_to_permuted_idx, Tensor topk_scales) -> Tensor");
     m.def(
         "moe_output_memset_inplace(Tensor(a!) input, Tensor tile_idx_to_mn_limit, Tensor expanded_idx_to_permuted_idx, "
@@ -510,6 +527,7 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
     m.impl("moe_topk_sort", &tensorrt_llm::torch_ext::moe_topk_sort);
     m.impl("moe_sort", &tensorrt_llm::torch_ext::moe_sort);
     m.impl("moe_permute", &tensorrt_llm::torch_ext::moe_permute);
+    m.impl("moe_unpermute_inplace", &tensorrt_llm::torch_ext::moe_unpermute_inplace);
     m.impl("moe_unpermute", &tensorrt_llm::torch_ext::moe_unpermute);
     m.impl("moe_output_memset_inplace", &tensorrt_llm::torch_ext::moe_output_memset_inplace);
     m.impl("moe_swiglu", &tensorrt_llm::torch_ext::moe_swiglu);
diff --git a/tensorrt_llm/_torch/compilation/utils.py b/tensorrt_llm/_torch/compilation/utils.py
index 2dc6914bc2..07430c6017 100644
--- a/tensorrt_llm/_torch/compilation/utils.py
+++ b/tensorrt_llm/_torch/compilation/utils.py
@@ -77,6 +77,9 @@ def inplace_info():
         torch.ops.trtllm.logits_bitmask.default: {
             1: "logits"
         },
+        torch.ops.trtllm.moe_unpermute_inplace.default: {
+            2: "output"
+        },
         torch.ops.trtllm.moe_output_memset_inplace.default: {
             1: "input"
         },
diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
index 12e1eb3ca0..7aa51a938d 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
@@ -991,23 +991,17 @@ class ConfigurableMoE(MoE):
         if not isinstance(self.comm, NVLinkOneSided):
             return None
 
+        if not self.backend.supports_moe_output_in_alltoall_workspace():
+            # Ensure payload_in_workspace is False if backend doesn't support it
+            self.comm.payload_in_workspace = False
+            return None
+
         # Determine workspace dtype and whether backend supports workspace output
         workspace_dtype = output_dtype
-        backend_supports_workspace = False
-
         if isinstance(self.backend, TRTLLMGenFusedMoE):
             # TRTLLMGen specific configuration
             self.comm.invalid_token_expert_id = -1
             workspace_dtype = torch.bfloat16
-            backend_supports_workspace = self.backend.has_w4a8_mxfp4_mxfp8
-        elif isinstance(self.backend, CutlassFusedMoE):
-            # Cutlass always supports workspace output with NVLinkOneSided
-            backend_supports_workspace = True
-
-        if not backend_supports_workspace:
-            # Ensure payload_in_workspace is False if backend doesn't support it
-            self.comm.payload_in_workspace = False
-            return None
 
         # Calculate runtime max tokens per rank
         assert all_rank_num_tokens is not None, (
@@ -1022,7 +1016,6 @@ class ConfigurableMoE(MoE):
 
         # Dynamically enable payload_in_workspace for this forward pass
         self.comm.payload_in_workspace = True
-
         return moe_output
 
     def _get_backend_kwargs(
@@ -1096,13 +1089,18 @@ class ConfigurableMoE(MoE):
 
             # Get moe_output for NVLinkOneSided backend
             kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
-                all_rank_num_tokens, output_dtype
+                all_rank_num_tokens=all_rank_num_tokens, output_dtype=output_dtype
             )
 
         # CuteDSL-specific parameters
         elif self.backend.__class__ == CuteDslFusedMoE:
             kwargs["enable_alltoall"] = self.enable_alltoall
 
+            # Get moe_output for NVLinkOneSided backend
+            kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
+                all_rank_num_tokens=all_rank_num_tokens, output_dtype=output_dtype
+            )
+
         # DeepGemm-specific parameters
         elif self.backend.__class__ == DeepGemmFusedMoE:
             if workspace is not None:
@@ -1123,7 +1121,7 @@ class ConfigurableMoE(MoE):
 
             # Get moe_output for NVLinkOneSided backend
             kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
-                all_rank_num_tokens, output_dtype
+                all_rank_num_tokens=all_rank_num_tokens, output_dtype=output_dtype
             )
 
         return kwargs
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
index 0ecd3e3e85..2cec8a269e 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
@@ -210,6 +210,9 @@ class CuteDslFusedMoE(CutlassFusedMoE):
             return NVFP4CuteDslFusedMoEMethod()
         return super()._get_quant_method()
 
+    def supports_moe_output_in_alltoall_workspace(self):
+        return self.has_nvfp4
+
     def quantize_input(self,
                        x: Union[torch.Tensor, Fp4QuantizedTensor],
                        post_quant_comm: bool = True):
@@ -258,6 +261,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
         token_selected_experts: torch.Tensor,
         token_final_scales: Optional[torch.Tensor],
         x_sf: Optional[torch.Tensor] = None,
+        moe_output: Optional[torch.Tensor] = None,
         enable_alltoall: bool = False,
     ) -> torch.Tensor:
         assert self.has_nvfp4
@@ -274,6 +278,16 @@ class CuteDslFusedMoE(CutlassFusedMoE):
             tile_tokens_dim=tile_size,
         )
 
+        if moe_output is None:
+            moe_output = torch.empty(
+                (token_final_scales.size(0), self.hidden_size),
+                dtype=output_dtype,
+                device=x.device)
+        else:
+            assert moe_output.size() == (token_final_scales.size(0),
+                                         self.hidden_size)
+            assert moe_output.dtype == output_dtype
+
         x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell(
             input=x.view(torch.float4_e2m1fn_x2),
             weight=self.w3_w1_weight.view(torch.float4_e2m1fn_x2),
@@ -291,12 +305,10 @@ class CuteDslFusedMoE(CutlassFusedMoE):
             local_expert_offset=self.slot_start,
             tile_size=tile_size,
         )
+
         if self.use_fused_finalize:
-            output = torch.empty((token_final_scales.size(0), self.hidden_size),
-                                 dtype=output_dtype,
-                                 device=x.device)
             torch.ops.trtllm.moe_output_memset_inplace(
-                input=output,
+                input=moe_output,
                 tile_idx_to_mn_limit=tile_idx_to_mn_limit,
                 expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
                 permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
@@ -313,7 +325,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
                 weight_scale=self.quant_scales.fc2_weight_block.view(
                     torch.uint8),
                 alpha=self.quant_scales.fc2_global,
-                output=output,
+                output=moe_output,
                 tile_idx_to_group_idx=tile_idx_to_expert_idx,
                 tile_idx_to_mn_limit=tile_idx_to_mn_limit,
                 permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
@@ -326,7 +338,6 @@ class CuteDslFusedMoE(CutlassFusedMoE):
                 tile_size=tile_size,
                 output_dtype=output_dtype,
             )
-            x = output
         else:
             x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell(
                 input=x.view(torch.float4_e2m1fn_x2),
@@ -344,12 +355,13 @@ class CuteDslFusedMoE(CutlassFusedMoE):
                 tile_size=tile_size,
                 output_dtype=output_dtype,
             )
-            x = torch.ops.trtllm.moe_unpermute(
+            torch.ops.trtllm.moe_unpermute_inplace(
                 permuted_input=x,
+                output=moe_output,
                 expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
                 topk_scales=token_final_scales,
             )
-        return x
+        return moe_output
 
     def run_moe_fp8_block_scales(
         self,
@@ -364,12 +376,12 @@ class CuteDslFusedMoE(CutlassFusedMoE):
         weight_dtype = self.w3_w1_weight.dtype
 
         (
-            permuted_row_to_unpermuted_row_tensor,
-            permuted_token_selected_experts_tensor,
-            permuted_data_tensor,
-            expert_first_token_offset_tensor,
-            permuted_token_final_scales_tensor,
-            unpermuted_row_to_permuted_row_tensor,
+            permuted_row_to_unpermuted_row,
+            permuted_token_selected_experts,
+            x,
+            expert_first_token_offset,
+            permuted_token_final_scales,
+            unpermuted_row_to_permuted_row,
         ) = torch.ops.trtllm.moe_permute_op(
             x,
             token_selected_experts,
@@ -388,35 +400,34 @@ class CuteDslFusedMoE(CutlassFusedMoE):
             min_latency_mode=False,
             use_fp8_block_scaling=True,
         )
-        act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
-            permuted_data_tensor)
-        h1 = cute_dsl_fp8_group_blockwise_gemm_ref(
-            a=act_input_fp8,
+        x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x)
+        x = cute_dsl_fp8_group_blockwise_gemm_ref(
+            a=x,
             b=self.w3_w1_weight.view(weight_dtype),
-            a_sf=act_input_sf,
+            a_sf=x_sf,
             b_sf=self.quant_scales[0],
-            offset_array=expert_first_token_offset_tensor,
+            offset_array=expert_first_token_offset,
         )
-        h2 = swiglu_fused_moe(h1)
-        act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(h2)
-        h3 = cute_dsl_fp8_group_blockwise_gemm_ref(
-            a=act_input_fp8,
+        x = swiglu_fused_moe(x)
+        x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x)
+        x = cute_dsl_fp8_group_blockwise_gemm_ref(
+            a=x,
             b=self.w2_weight.view(weight_dtype),
-            a_sf=act_input_sf,
+            a_sf=x_sf,
             b_sf=self.quant_scales[1],
-            offset_array=expert_first_token_offset_tensor,
+            offset_array=expert_first_token_offset,
         )
-        h4 = torch.ops.trtllm.moe_finalize_scale_op(
-            h3,
+        x = torch.ops.trtllm.moe_finalize_scale_op(
+            x,
             None,  # biases
             token_final_scales,
-            unpermuted_row_to_permuted_row_tensor,
-            permuted_row_to_unpermuted_row_tensor,
+            unpermuted_row_to_permuted_row,
+            permuted_row_to_unpermuted_row,
             token_selected_experts,
-            expert_first_token_offset_tensor,
+            expert_first_token_offset,
             enable_alltoall,
-            x.shape[0],  # num_rows
-            x.shape[1],  # (possibly padded) hidden_size
+            token_final_scales.size(0),  # num_rows
+            self.hidden_size,  # (possibly padded) hidden_size
             self.unpadded_hidden_size,  # original hidden size
             self.routing_method.top_k,
             self.expert_size_per_partition,  # num_experts_per_node
@@ -425,7 +436,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
             self.ep_size,
             self.ep_rank,
         )
-        return h4
+        return x
 
     def run_moe(
         self,
         x: torch.Tensor,
         token_selected_experts: torch.Tensor,
         token_final_scales: Optional[torch.Tensor],
         x_sf: Optional[torch.Tensor] = None,
+        moe_output: Optional[torch.Tensor] = None,
         enable_alltoall: bool = False,
     ) -> torch.Tensor:
         """
@@ -448,6 +460,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
                 this represents expert slots [num_tokens, top_k] instead.
             token_final_scales: Final scaling factors for each token
             x_sf: Input scale factors (optional, for certain quantization schemes)
+            moe_output: Pre-allocated MoE output buffer (optional, for NVLINK one-sided backend).
             enable_alltoall: Whether alltoall communication is enabled.
 
         Returns:
@@ -459,6 +472,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
                 token_selected_experts=token_selected_experts,
                 token_final_scales=token_final_scales,
                 x_sf=x_sf,
+                moe_output=moe_output,
                 enable_alltoall=enable_alltoall)
         elif self.has_deepseek_fp8_block_scales:
             return self.run_moe_fp8_block_scales(
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
index 534c89d104..71e13e1324 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
@@ -389,6 +389,9 @@ class CutlassFusedMoE(MoE):
         self._weights_created = True
         self._check_configs()
 
+    def supports_moe_output_in_alltoall_workspace(self):
+        return True
+
     def run_moe(
         self,
         x: torch.Tensor,
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
index ba735bc1a2..2af4bb6f85 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
@@ -354,6 +354,9 @@ class TRTLLMGenFusedMoE(MoE):
 
         return x, x_sf
 
+    def supports_moe_output_in_alltoall_workspace(self):
+        return self.has_w4a8_mxfp4_mxfp8
+
     def run_moe(
         self,
         x: torch.Tensor,
diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py
index e415d0cc1b..e6d7797b9b 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/interface.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py
@@ -723,6 +723,11 @@ class MoE(nn.Module):
     def expand_intermediate_size_per_partition(self):
         return self.intermediate_size_per_partition * self.intermediate_size_expand_ratio
 
+    def supports_moe_output_in_alltoall_workspace(self):
+        """ Supports moe_output in alltoall workspace
+        """
+        return False
+
     def reducescatter_or_allreduce(
         self,
         inputs,