mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-04 18:21:52 +08:00)
[TRTLLM-9390][chore] Add Fake OPs for One-Sided AlltoAll. (#11002)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
parent 93ae8a14ab
commit 6b251cc7fa
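Background for the diff below: torch.library.register_fake attaches a Python "fake" implementation to an operator whose real kernel lives in C++/CUDA. Under FakeTensorMode or torch.compile tracing, the fake runs instead of the kernel and only propagates shapes, dtypes, and devices, so graphs containing the one-sided all-to-all ops can be traced without launching any communication. A minimal sketch of the pattern, using a hypothetical demo::gather_tokens op rather than the real trtllm ones:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Hypothetical op for illustration; the real trtllm ops are defined in C++.
torch.library.define("demo::gather_tokens",
                     "(Tensor payload, int ep_size) -> Tensor")

@torch.library.register_fake("demo::gather_tokens")
def _(payload: torch.Tensor, ep_size: int) -> torch.Tensor:
    # Allocate an output with the right metadata; never touch real data here.
    return payload.new_empty((ep_size * payload.shape[0], payload.shape[1]))

with FakeTensorMode():
    x = torch.empty(8, 4096)
    y = torch.ops.demo.gather_tokens(x, 4)
    assert y.shape == (4 * 8, 4096)  # shape inferred, no kernel launch needed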
@@ -327,6 +327,84 @@ def _register_fake():
             outputs.append(output_tensor)
         return outputs
 
+    @torch.library.register_fake("trtllm::moe_a2a_dispatch")
+    def _(
+        token_selected_experts: torch.Tensor,
+        input_payloads: List[torch.Tensor],
+        workspace: torch.Tensor,
+        metainfo: torch.Tensor,
+        runtime_max_tokens_per_rank: int,
+        ep_rank: int,
+        ep_size: int,
+        top_k: int,
+        num_experts: int,
+        eplb_local_stats: Optional[torch.Tensor] = None,
+    ) -> Tuple[List[torch.Tensor], int, torch.Tensor]:
+        recv_tensors: List[torch.Tensor] = []
+        for payload in input_payloads:
+            elements_per_token = payload.shape[1]
+            recv_tensors.append(
+                payload.new_empty(
+                    (ep_size, runtime_max_tokens_per_rank, elements_per_token)))
+
+        if eplb_local_stats is None:
+            eplb_gathered_stats = workspace.new_empty((0, ), dtype=torch.int32)
+        else:
+            eplb_gathered_stats = workspace.new_empty(
+                (ep_size, eplb_local_stats.shape[0]), dtype=torch.int32)
+
+        combine_payload_offset = 0
+        return recv_tensors, combine_payload_offset, eplb_gathered_stats
+
+    @torch.library.register_fake("trtllm::moe_a2a_combine")
+    def _(
+        payload: torch.Tensor,
+        local_num_tokens: int,
+        workspace: torch.Tensor,
+        metainfo: torch.Tensor,
+        runtime_max_tokens_per_rank: int,
+        ep_rank: int,
+        ep_size: int,
+        top_k: int,
+        combine_payload_offset: int,
+        payload_in_workspace: bool,
+    ) -> torch.Tensor:
+        return payload.new_empty((local_num_tokens, payload.shape[2]))
+
+    @torch.library.register_fake("trtllm::moe_a2a_initialize")
+    def _(
+        workspace: torch.Tensor,
+        ep_rank: int,
+        ep_size: int,
+        max_num_tokens_per_rank: int,
+        eplb_stats_num_experts: Optional[int] = None,
+    ) -> torch.Tensor:
+        return torch.empty((10, ), dtype=torch.int64, device="cpu")
+
+    @torch.library.register_fake("trtllm::moe_a2a_sanitize_expert_ids")
+    def _(
+        expert_ids: torch.Tensor,
+        workspace: torch.Tensor,
+        metainfo: torch.Tensor,
+        ep_rank: int,
+        invalid_expert_id: int,
+    ) -> None:
+        return None
+
+    @torch.library.register_fake("trtllm::moe_a2a_get_combine_payload_tensor")
+    def _(
+        workspace: torch.Tensor,
+        ep_rank: int,
+        ep_size: int,
+        runtime_max_tokens_per_rank: int,
+        combine_payload_offset: int,
+        out_dtype: torch.dtype,
+        hidden_size: int,
+    ) -> torch.Tensor:
+        return workspace.new_empty(
+            (ep_size * runtime_max_tokens_per_rank, hidden_size),
+            dtype=out_dtype)
+
+    @torch.library.register_fake("trtllm::get_moe_commworkspace_size_per_rank")
+    def _(ep_size: int):
+        return 0
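A hedged usage sketch of the fakes above. Assumptions: the trtllm custom ops are already registered (importing TensorRT-LLM's torch custom-op module does this), and the shapes and dtypes below are illustrative stand-ins, not values taken from the real kernels. Under FakeTensorMode, dispatch then traces end to end without a GPU launch:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

ep_rank, ep_size, top_k, num_experts = 0, 4, 2, 8
max_tokens, hidden = 16, 4096

with FakeTensorMode():
    workspace = torch.empty(1 << 20, dtype=torch.uint8, device="cuda")
    metainfo = torch.ops.trtllm.moe_a2a_initialize(workspace, ep_rank,
                                                   ep_size, max_tokens)
    experts = torch.empty(max_tokens, top_k, dtype=torch.int32, device="cuda")
    payloads = [
        torch.empty(max_tokens, hidden, dtype=torch.bfloat16, device="cuda")
    ]
    recv, combine_offset, eplb_stats = torch.ops.trtllm.moe_a2a_dispatch(
        experts, payloads, workspace, metainfo, max_tokens, ep_rank, ep_size,
        top_k, num_experts)

    # Per the dispatch fake: one (ep_size, max_tokens, elements_per_token)
    # receive buffer per payload, and empty EPLB stats when none are passed.
    assert recv[0].shape == (ep_size, max_tokens, hidden)
    assert eplb_stats.numel() == 0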
@@ -99,12 +99,6 @@ def test_register_fake(custom_ops):
         "trtllm::e4m3_mxe2m1_block_scale_moe_runner",
         "trtllm::mxe4m3_mxe2m1_block_scale_moe_runner",
         "trtllm::mxfp8_quantize",
-        "trtllm::moe_a2a_dispatch",
-        "trtllm::moe_a2a_combine",
-        "trtllm::moe_a2a_initialize",
-        "trtllm::moe_a2a_get_combine_payload_tensor",
-        "trtllm::moe_a2a_sanitize_expert_ids",
-        "trtllm::moe_a2a_get_aux_data_size",
     }
 
     ops_missing_fake_impl = []
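The six moe_a2a entries drop out of the expected-missing set because they now have fake impls. A minimal sketch (an assumed helper, not the project's actual test code) of the property this test guards: an op without a fake or meta impl fails under FakeTensorMode:

import torch
from torch._subclasses.fake_tensor import (FakeTensorMode,
                                           UnsupportedOperatorException)

def has_fake_impl(op, *args, **kwargs) -> bool:
    """True if `op` runs under FakeTensorMode, i.e. a fake/meta impl exists."""
    try:
        # allow_non_fake_inputs lets real example tensors cross into fake mode.
        with FakeTensorMode(allow_non_fake_inputs=True):
            op(*args, **kwargs)
    except UnsupportedOperatorException:
        return False
    return True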