From 6b251cc7fa8fe2f795e47222912ec2046b0b2fce Mon Sep 17 00:00:00 2001
From: Bo Li <22713281+bobboli@users.noreply.github.com>
Date: Tue, 27 Jan 2026 15:55:07 +0800
Subject: [PATCH] [TRTLLM-9390][chore] Add Fake OPs for One-Sided AlltoAll.
 (#11002)

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
---
 .../_torch/custom_ops/cpp_custom_ops.py      | 78 +++++++++++++++++++
 .../_torch/thop/parallel/test_custom_ops.py  |  6 --
 2 files changed, 78 insertions(+), 6 deletions(-)

diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
index 73ee7ee3d6..3a3ee1238b 100644
--- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
@@ -327,6 +327,84 @@ def _register_fake():
             outputs.append(output_tensor)
         return outputs
 
+    @torch.library.register_fake("trtllm::moe_a2a_dispatch")
+    def _(
+        token_selected_experts: torch.Tensor,
+        input_payloads: List[torch.Tensor],
+        workspace: torch.Tensor,
+        metainfo: torch.Tensor,
+        runtime_max_tokens_per_rank: int,
+        ep_rank: int,
+        ep_size: int,
+        top_k: int,
+        num_experts: int,
+        eplb_local_stats: Optional[torch.Tensor] = None,
+    ) -> Tuple[List[torch.Tensor], int, torch.Tensor]:
+        recv_tensors: List[torch.Tensor] = []
+        for payload in input_payloads:
+            elements_per_token = payload.shape[1]
+            recv_tensors.append(
+                payload.new_empty(
+                    (ep_size, runtime_max_tokens_per_rank, elements_per_token)))
+
+        if eplb_local_stats is None:
+            eplb_gathered_stats = workspace.new_empty((0, ), dtype=torch.int32)
+        else:
+            eplb_gathered_stats = workspace.new_empty(
+                (ep_size, eplb_local_stats.shape[0]), dtype=torch.int32)
+
+        combine_payload_offset = 0
+        return recv_tensors, combine_payload_offset, eplb_gathered_stats
+
+    @torch.library.register_fake("trtllm::moe_a2a_combine")
+    def _(
+        payload: torch.Tensor,
+        local_num_tokens: int,
+        workspace: torch.Tensor,
+        metainfo: torch.Tensor,
+        runtime_max_tokens_per_rank: int,
+        ep_rank: int,
+        ep_size: int,
+        top_k: int,
+        combine_payload_offset: int,
+        payload_in_workspace: bool,
+    ) -> torch.Tensor:
+        return payload.new_empty((local_num_tokens, payload.shape[2]))
+
+    @torch.library.register_fake("trtllm::moe_a2a_initialize")
+    def _(
+        workspace: torch.Tensor,
+        ep_rank: int,
+        ep_size: int,
+        max_num_tokens_per_rank: int,
+        eplb_stats_num_experts: Optional[int] = None,
+    ) -> torch.Tensor:
+        return torch.empty((10, ), dtype=torch.int64, device="cpu")
+
+    @torch.library.register_fake("trtllm::moe_a2a_sanitize_expert_ids")
+    def _(
+        expert_ids: torch.Tensor,
+        workspace: torch.Tensor,
+        metainfo: torch.Tensor,
+        ep_rank: int,
+        invalid_expert_id: int,
+    ) -> None:
+        return None
+
+    @torch.library.register_fake("trtllm::moe_a2a_get_combine_payload_tensor")
+    def _(
+        workspace: torch.Tensor,
+        ep_rank: int,
+        ep_size: int,
+        runtime_max_tokens_per_rank: int,
+        combine_payload_offset: int,
+        out_dtype: torch.dtype,
+        hidden_size: int,
+    ) -> torch.Tensor:
+        return workspace.new_empty(
+            (ep_size * runtime_max_tokens_per_rank, hidden_size),
+            dtype=out_dtype)
+
     @torch.library.register_fake("trtllm::get_moe_commworkspace_size_per_rank")
     def _(ep_size: int):
         return 0
diff --git a/tests/unittest/_torch/thop/parallel/test_custom_ops.py b/tests/unittest/_torch/thop/parallel/test_custom_ops.py
index 5d65b83a75..4d8ca8ba86 100644
--- a/tests/unittest/_torch/thop/parallel/test_custom_ops.py
+++ b/tests/unittest/_torch/thop/parallel/test_custom_ops.py
@@ -99,12 +99,6 @@ def test_register_fake(custom_ops):
         "trtllm::e4m3_mxe2m1_block_scale_moe_runner",
         "trtllm::mxe4m3_mxe2m1_block_scale_moe_runner",
         "trtllm::mxfp8_quantize",
-        "trtllm::moe_a2a_dispatch",
-        "trtllm::moe_a2a_combine",
-        "trtllm::moe_a2a_initialize",
-        "trtllm::moe_a2a_get_combine_payload_tensor",
-        "trtllm::moe_a2a_sanitize_expert_ids",
-        "trtllm::moe_a2a_get_aux_data_size",
     }
 
     ops_missing_fake_impl = []