[TRTLLM-9390][chore] Add Fake OPs for One-Sided AlltoAll. (#11002)

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Author: Bo Li, 2026-01-27 15:55:07 +08:00 (committed by GitHub)
Parent: 93ae8a14ab
Commit: 6b251cc7fa
2 changed files with 78 additions and 6 deletions


@@ -327,6 +327,84 @@ def _register_fake():
            outputs.append(output_tensor)
        return outputs

    @torch.library.register_fake("trtllm::moe_a2a_dispatch")
    def _(
        token_selected_experts: torch.Tensor,
        input_payloads: List[torch.Tensor],
        workspace: torch.Tensor,
        metainfo: torch.Tensor,
        runtime_max_tokens_per_rank: int,
        ep_rank: int,
        ep_size: int,
        top_k: int,
        num_experts: int,
        eplb_local_stats: Optional[torch.Tensor] = None,
    ) -> Tuple[List[torch.Tensor], int, torch.Tensor]:
        # Shape-only allocation: each rank may receive up to
        # runtime_max_tokens_per_rank tokens from every peer, per payload.
        recv_tensors: List[torch.Tensor] = []
        for payload in input_payloads:
            elements_per_token = payload.shape[1]
            recv_tensors.append(
                payload.new_empty(
                    (ep_size, runtime_max_tokens_per_rank, elements_per_token)))
        if eplb_local_stats is None:
            eplb_gathered_stats = workspace.new_empty((0, ), dtype=torch.int32)
        else:
            eplb_gathered_stats = workspace.new_empty(
                (ep_size, eplb_local_stats.shape[0]), dtype=torch.int32)
        combine_payload_offset = 0
        return recv_tensors, combine_payload_offset, eplb_gathered_stats

    @torch.library.register_fake("trtllm::moe_a2a_combine")
    def _(
        payload: torch.Tensor,
        local_num_tokens: int,
        workspace: torch.Tensor,
        metainfo: torch.Tensor,
        runtime_max_tokens_per_rank: int,
        ep_rank: int,
        ep_size: int,
        top_k: int,
        combine_payload_offset: int,
        payload_in_workspace: bool,
    ) -> torch.Tensor:
        # One output row per local token; payload is [ep_size, max_tokens, hidden].
        return payload.new_empty((local_num_tokens, payload.shape[2]))

    @torch.library.register_fake("trtllm::moe_a2a_initialize")
    def _(
        workspace: torch.Tensor,
        ep_rank: int,
        ep_size: int,
        max_num_tokens_per_rank: int,
        eplb_stats_num_experts: Optional[int] = None,
    ) -> torch.Tensor:
        # Placeholder metainfo tensor on CPU.
        return torch.empty((10, ), dtype=torch.int64, device="cpu")

    @torch.library.register_fake("trtllm::moe_a2a_sanitize_expert_ids")
    def _(
        expert_ids: torch.Tensor,
        workspace: torch.Tensor,
        metainfo: torch.Tensor,
        ep_rank: int,
        invalid_expert_id: int,
    ) -> None:
        # In-place sanitization: no tensor output.
        return None

    @torch.library.register_fake("trtllm::moe_a2a_get_combine_payload_tensor")
    def _(
        workspace: torch.Tensor,
        ep_rank: int,
        ep_size: int,
        runtime_max_tokens_per_rank: int,
        combine_payload_offset: int,
        out_dtype: torch.dtype,
        hidden_size: int,
    ) -> torch.Tensor:
        # Flattened [ep_size * max_tokens, hidden] combine payload buffer.
        return workspace.new_empty(
            (ep_size * runtime_max_tokens_per_rank, hidden_size),
            dtype=out_dtype)

    @torch.library.register_fake("trtllm::get_moe_commworkspace_size_per_rank")
    def _(ep_size: int):
        # Size query; any int suffices for tracing.
        return 0
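
These registrations are the 78 added lines: shape-and-dtype-only implementations that let the one-sided AlltoAll ops be traced under FakeTensorMode (e.g. by torch.compile) without launching kernels or touching communication buffers. A minimal standalone sketch of the mechanism, using a hypothetical my_lib::double_rows op that is not part of this commit (assumes the PyTorch >= 2.4 torch.library API):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Declare a custom op schema; no real kernel is registered in this sketch.
torch.library.define("my_lib::double_rows", "(Tensor x) -> Tensor")

@torch.library.register_fake("my_lib::double_rows")
def _(x: torch.Tensor) -> torch.Tensor:
    # Metadata-only allocation: twice the rows, same trailing dims and dtype.
    return x.new_empty((2 * x.shape[0], *x.shape[1:]))

with FakeTensorMode():
    x = torch.empty(4, 8)                  # a FakeTensor under this mode
    out = torch.ops.my_lib.double_rows(x)  # dispatches to the fake impl
    assert out.shape == (8, 8)             # shapes propagate, no kernel runs

The fakes above follow the same pattern: moe_a2a_dispatch allocates [ep_size, runtime_max_tokens_per_rank, elements_per_token] receive buffers per payload, and moe_a2a_combine returns a [local_num_tokens, hidden] tensor, so downstream shape inference sees what the real kernels would produce.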


@@ -99,12 +99,6 @@ def test_register_fake(custom_ops):
"trtllm::e4m3_mxe2m1_block_scale_moe_runner",
"trtllm::mxe4m3_mxe2m1_block_scale_moe_runner",
"trtllm::mxfp8_quantize",
"trtllm::moe_a2a_dispatch",
"trtllm::moe_a2a_combine",
"trtllm::moe_a2a_initialize",
"trtllm::moe_a2a_get_combine_payload_tensor",
"trtllm::moe_a2a_sanitize_expert_ids",
"trtllm::moe_a2a_get_aux_data_size",
}
ops_missing_fake_impl = []
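
The six trtllm::moe_a2a_* entries are this file's deletions: they were allow-listed as ops known to lack fake implementations, and with the registrations above they must now pass the check. A rough sketch of how such a coverage probe can work (illustrative assumption; op_runs_under_fake_mode is a made-up helper, and the actual test body is not shown in this diff):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode, UnsupportedOperatorException

def op_runs_under_fake_mode(op, *args) -> bool:
    # Probe: ops lacking a fake/meta impl raise UnsupportedOperatorException
    # when executed under FakeTensorMode.
    try:
        with FakeTensorMode(allow_non_fake_inputs=True):
            op(*args)
        return True
    except UnsupportedOperatorException:
        return False

# Stock aten ops ship meta implementations, so the probe passes:
assert op_runs_under_fake_mode(torch.ops.aten.relu.default, torch.randn(2, 2))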