Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[TRTLLM-9819][perf] Reuse alltoall workspace for CuteDSL MoE output (#9840)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Commit 6fe89ea00f (parent 0b279f4ad4)
@@ -205,24 +205,26 @@ std::tuple<torch::Tensor, torch::optional<torch::Tensor>> moe_permute(torch::Ten

 // Unpermute
-torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor const& expanded_idx_to_permuted_idx,
-    torch::Tensor const& topk_scales)
+void moe_unpermute_inplace(torch::Tensor const& permuted_input, torch::Tensor const& output,
+    torch::Tensor const& expanded_idx_to_permuted_idx, torch::Tensor const& topk_scales)
 {
     TORCH_CHECK(permuted_input.dim() == 2, "permuted_input must be 2D.");
+    int64_t const max_num_permuted_tokens = permuted_input.size(0);
     int64_t const hidden_size = permuted_input.size(1);
+    TORCH_CHECK(output.dim() == 2, "output must be 2D.");
+    int64_t const num_tokens = output.size(0);
+    TORCH_CHECK(output.size(1) == hidden_size, "output.size(1) must be hidden_size.");
+
     TORCH_CHECK(expanded_idx_to_permuted_idx.dim() == 2, "expanded_idx_to_permuted_idx must be 2D.");
-    int64_t const num_tokens = expanded_idx_to_permuted_idx.size(0);
+    TORCH_CHECK(
+        expanded_idx_to_permuted_idx.size(0) == num_tokens, "expanded_idx_to_permuted_idx.size(0) must be num_tokens.");
     int64_t const top_k = expanded_idx_to_permuted_idx.size(1);
     TORCH_CHECK(topk_scales.dim() == 2, "topk_scales must be 2D.");
     TORCH_CHECK(topk_scales.size(0) == num_tokens, "topk_scales.size(0) must be num_tokens.");
     TORCH_CHECK(topk_scales.size(1) == top_k, "topk_scales.size(1) must be top_k.");

+    TORCH_CHECK(max_num_permuted_tokens >= num_tokens * top_k,
+        "max_num_permuted_tokens must be greater than or equal to num_tokens * top_k.");
+
-    auto output
-        = torch::empty({num_tokens, hidden_size}, torch::dtype(permuted_input.scalar_type()).device(torch::kCUDA));
     auto const& stream = at::cuda::getCurrentCUDAStream(permuted_input.get_device());

 #define DISPATCH_MOE_UNPERMUTE(InputType, TopKScaleType) \
@@ -253,7 +255,19 @@ torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor c
     }

 #undef DISPATCH_MOE_UNPERMUTE
 }

+torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor const& expanded_idx_to_permuted_idx,
+    torch::Tensor const& topk_scales)
+{
+    TORCH_CHECK(permuted_input.dim() == 2, "permuted_input must be 2D.");
+    int64_t const hidden_size = permuted_input.size(1);
+    TORCH_CHECK(expanded_idx_to_permuted_idx.dim() == 2, "expanded_idx_to_permuted_idx must be 2D.");
+    int64_t const num_tokens = expanded_idx_to_permuted_idx.size(0);
+
+    auto output
+        = torch::empty({num_tokens, hidden_size}, torch::dtype(permuted_input.scalar_type()).device(torch::kCUDA));
+    moe_unpermute_inplace(permuted_input, output, expanded_idx_to_permuted_idx, topk_scales);
+    return output;
+}
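Taken together, the two entry points now share one kernel path: `moe_unpermute` allocates a fresh result and delegates to `moe_unpermute_inplace`, which writes into a caller-owned buffer. A minimal usage sketch; the shapes, dtypes, and index layout below are illustrative assumptions, not taken from the diff:

```python
# Sketch: the allocating op vs. the new in-place op (assumes the trtllm
# extension is loaded; dtypes here are assumptions).
import torch

num_tokens, top_k, hidden_size = 8, 2, 128
permuted = torch.randn(num_tokens * top_k, hidden_size,
                       device="cuda", dtype=torch.bfloat16)
# (token, k) pair -> row index in the permuted activation tensor
idx = torch.arange(num_tokens * top_k, device="cuda",
                   dtype=torch.int32).view(num_tokens, top_k)
scales = torch.rand(num_tokens, top_k, device="cuda")

# Allocating variant: returns a fresh [num_tokens, hidden_size] tensor.
out = torch.ops.trtllm.moe_unpermute(permuted, idx, scales)

# In-place variant: writes into a buffer we own, e.g. a view of the
# alltoall workspace, so no extra allocation or copy-out is needed.
buf = torch.empty(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16)
torch.ops.trtllm.moe_unpermute_inplace(permuted, buf, idx, scales)
```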
@@ -489,6 +503,9 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
     m.def(
         "moe_permute(Tensor input, Tensor? input_sf, Tensor tile_idx_to_mn_limit, Tensor permuted_idx_to_expanded_idx, "
         "Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k) -> (Tensor, Tensor?)");
+    m.def(
+        "moe_unpermute_inplace(Tensor permuted_input, Tensor(a!) output, Tensor expanded_idx_to_permuted_idx, Tensor "
+        "topk_scales) -> ()");
     m.def("moe_unpermute(Tensor permuted_input, Tensor expanded_idx_to_permuted_idx, Tensor topk_scales) -> Tensor");
     m.def(
         "moe_output_memset_inplace(Tensor(a!) input, Tensor tile_idx_to_mn_limit, Tensor expanded_idx_to_permuted_idx, "
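The `Tensor(a!)` marker in the new schema declares that `output` is mutated in place, so the dispatcher and torch.compile know the op writes through that argument rather than returning a value. A self-contained sketch of the same convention on a toy op (the `demo` namespace and op are hypothetical):

```python
import torch

lib = torch.library.Library("demo", "DEF")  # hypothetical op namespace
# "(a!)" = this argument aliases an input and is written by the op.
lib.define("scale_inplace(Tensor(a!) out, float s) -> ()")

def scale_inplace_impl(out: torch.Tensor, s: float) -> None:
    out.mul_(s)  # mutate the aliased buffer, matching the (a!) annotation

lib.impl("scale_inplace", scale_inplace_impl, "CompositeExplicitAutograd")

x = torch.ones(4)
torch.ops.demo.scale_inplace(x, 2.0)
print(x)  # tensor([2., 2., 2., 2.])
```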
@@ -510,6 +527,7 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
     m.impl("moe_topk_sort", &tensorrt_llm::torch_ext::moe_topk_sort);
     m.impl("moe_sort", &tensorrt_llm::torch_ext::moe_sort);
     m.impl("moe_permute", &tensorrt_llm::torch_ext::moe_permute);
+    m.impl("moe_unpermute_inplace", &tensorrt_llm::torch_ext::moe_unpermute_inplace);
     m.impl("moe_unpermute", &tensorrt_llm::torch_ext::moe_unpermute);
     m.impl("moe_output_memset_inplace", &tensorrt_llm::torch_ext::moe_output_memset_inplace);
     m.impl("moe_swiglu", &tensorrt_llm::torch_ext::moe_swiglu);
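Once the extension carrying this registration is loaded, the op and its alias annotation can be sanity-checked from Python (the import that loads the extension is an assumption):

```python
import torch
import tensorrt_llm  # assumed to load the trtllm custom-op extension

schema = torch.ops.trtllm.moe_unpermute_inplace.default._schema
print(schema)  # ...(Tensor permuted_input, Tensor(a!) output, ...) -> ()
# The second argument carries write-alias info, matching Tensor(a!).
print(schema.arguments[1].alias_info.is_write)  # True
```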
@@ -77,6 +77,9 @@ def inplace_info():
         torch.ops.trtllm.logits_bitmask.default: {
             1: "logits"
         },
+        torch.ops.trtllm.moe_unpermute_inplace.default: {
+            2: "output"
+        },
         torch.ops.trtllm.moe_output_memset_inplace.default: {
             1: "input"
         },
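This registers the new op with the PyTorch-backend pass that tracks in-place custom ops: keys are op overloads, and values map 1-based argument positions to the name of the argument the op mutates (`output` is the second argument of `moe_unpermute_inplace`). A hypothetical consumer, just to show how the positions are read:

```python
# Hypothetical consumer of inplace_info(); the mapping itself is from
# this diff, the helper is not.
def mutated_args(op, args, info):
    """Yield (arg_name, value) for each argument `op` writes in place."""
    for pos, name in info.get(op, {}).items():
        yield name, args[pos - 1]  # convert 1-based position to 0-based index
```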
@@ -991,23 +991,17 @@ class ConfigurableMoE(MoE):
         if not isinstance(self.comm, NVLinkOneSided):
             return None

+        if not self.backend.supports_moe_output_in_alltoall_workspace():
+            # Ensure payload_in_workspace is False if backend doesn't support it
+            self.comm.payload_in_workspace = False
+            return None
+
         # Determine workspace dtype and whether backend supports workspace output
         workspace_dtype = output_dtype
-        backend_supports_workspace = False
-
         if isinstance(self.backend, TRTLLMGenFusedMoE):
             # TRTLLMGen specific configuration
             self.comm.invalid_token_expert_id = -1
             workspace_dtype = torch.bfloat16
-            backend_supports_workspace = self.backend.has_w4a8_mxfp4_mxfp8
-        elif isinstance(self.backend, CutlassFusedMoE):
-            # Cutlass always supports workspace output with NVLinkOneSided
-            backend_supports_workspace = True
-
-        if not backend_supports_workspace:
-            # Ensure payload_in_workspace is False if backend doesn't support it
-            self.comm.payload_in_workspace = False
-            return None

         # Calculate runtime max tokens per rank
         assert all_rank_num_tokens is not None, (
@@ -1022,7 +1016,6 @@ class ConfigurableMoE(MoE):

         # Dynamically enable payload_in_workspace for this forward pass
         self.comm.payload_in_workspace = True
-
         return moe_output

     def _get_backend_kwargs(
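What `payload_in_workspace` buys: the buffer returned by `_get_nvlink_onesided_moe_output` is backed by the alltoall workspace, so the backend's final write lands directly where the one-sided combine reads, skipping a separate copy-into-workspace step. A conceptual sketch with hypothetical names:

```python
# Conceptual sketch of "payload in workspace" (all names hypothetical).
import torch

hidden_size, max_tokens = 128, 64
workspace = torch.empty(max_tokens * hidden_size,
                        device="cuda", dtype=torch.bfloat16)

def get_moe_output_view(num_tokens: int) -> torch.Tensor:
    # A [num_tokens, hidden_size] view over the comm workspace; a kernel
    # that writes here has already "delivered" its payload.
    return workspace[: num_tokens * hidden_size].view(num_tokens, hidden_size)

moe_output = get_moe_output_view(32)
moe_output.zero_()  # writing the view writes the workspace directly
```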
@@ -1096,13 +1089,18 @@ class ConfigurableMoE(MoE):

             # Get moe_output for NVLinkOneSided backend
             kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
-                all_rank_num_tokens, output_dtype
+                all_rank_num_tokens=all_rank_num_tokens, output_dtype=output_dtype
             )

         # CuteDSL-specific parameters
         elif self.backend.__class__ == CuteDslFusedMoE:
             kwargs["enable_alltoall"] = self.enable_alltoall
+
+            # Get moe_output for NVLinkOneSided backend
+            kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
+                all_rank_num_tokens=all_rank_num_tokens, output_dtype=output_dtype
+            )

         # DeepGemm-specific parameters
         elif self.backend.__class__ == DeepGemmFusedMoE:
             if workspace is not None:

@@ -1123,7 +1121,7 @@ class ConfigurableMoE(MoE):

             # Get moe_output for NVLinkOneSided backend
             kwargs["moe_output"] = self._get_nvlink_onesided_moe_output(
-                all_rank_num_tokens, output_dtype
+                all_rank_num_tokens=all_rank_num_tokens, output_dtype=output_dtype
             )

         return kwargs
@@ -210,6 +210,9 @@ class CuteDslFusedMoE(CutlassFusedMoE):
             return NVFP4CuteDslFusedMoEMethod()
         return super()._get_quant_method()

+    def supports_moe_output_in_alltoall_workspace(self):
+        return self.has_nvfp4
+
     def quantize_input(self,
                        x: Union[torch.Tensor, Fp4QuantizedTensor],
                        post_quant_comm: bool = True):
@@ -258,6 +261,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
         token_selected_experts: torch.Tensor,
         token_final_scales: Optional[torch.Tensor],
         x_sf: Optional[torch.Tensor] = None,
+        moe_output: Optional[torch.Tensor] = None,
         enable_alltoall: bool = False,
     ) -> torch.Tensor:
         assert self.has_nvfp4
@@ -274,6 +278,16 @@ class CuteDslFusedMoE(CutlassFusedMoE):
             tile_tokens_dim=tile_size,
         )

+        if moe_output is None:
+            moe_output = torch.empty(
+                (token_final_scales.size(0), self.hidden_size),
+                dtype=output_dtype,
+                device=x.device)
+        else:
+            assert moe_output.size() == (token_final_scales.size(0),
+                                         self.hidden_size)
+            assert moe_output.dtype == output_dtype
+
         x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell(
             input=x.view(torch.float4_e2m1fn_x2),
             weight=self.w3_w1_weight.view(torch.float4_e2m1fn_x2),
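The `if moe_output is None` block is the allocate-or-validate pattern in isolation: reuse the caller's workspace-backed buffer when one is passed in, otherwise keep the old allocating behavior. A standalone sketch (the helper name is illustrative):

```python
from typing import Optional
import torch

def get_output_buffer(num_tokens: int, hidden_size: int,
                      dtype: torch.dtype, device: torch.device,
                      out: Optional[torch.Tensor] = None) -> torch.Tensor:
    if out is None:
        return torch.empty((num_tokens, hidden_size),
                           dtype=dtype, device=device)
    # A mismatched buffer would silently corrupt results, so fail loudly.
    assert out.size() == (num_tokens, hidden_size) and out.dtype == dtype
    return out
```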
@@ -291,12 +305,10 @@ class CuteDslFusedMoE(CutlassFusedMoE):
             local_expert_offset=self.slot_start,
             tile_size=tile_size,
         )

         if self.use_fused_finalize:
-            output = torch.empty((token_final_scales.size(0), self.hidden_size),
-                                 dtype=output_dtype,
-                                 device=x.device)
             torch.ops.trtllm.moe_output_memset_inplace(
-                input=output,
+                input=moe_output,
                 tile_idx_to_mn_limit=tile_idx_to_mn_limit,
                 expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
                 permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,

@@ -313,7 +325,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
                 weight_scale=self.quant_scales.fc2_weight_block.view(
                     torch.uint8),
                 alpha=self.quant_scales.fc2_global,
-                output=output,
+                output=moe_output,
                 tile_idx_to_group_idx=tile_idx_to_expert_idx,
                 tile_idx_to_mn_limit=tile_idx_to_mn_limit,
                 permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,

@@ -326,7 +338,6 @@ class CuteDslFusedMoE(CutlassFusedMoE):
                 tile_size=tile_size,
                 output_dtype=output_dtype,
             )
-            x = output
         else:
             x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell(
                 input=x.view(torch.float4_e2m1fn_x2),
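Note that with the fused finalize the destination is now the reused workspace, which holds stale bytes from the previous alltoall, so `moe_output_memset_inplace` zero-initializes the relevant rows before the finalize epilogue accumulates per-expert partial results into them. A toy illustration of why the memset is load-bearing (assuming an accumulate-style finalize):

```python
import torch

out = torch.full((4, 8), 7.0)          # stale workspace contents
contributions = [torch.ones(4, 8), 2 * torch.ones(4, 8)]

out.zero_()                            # the memset step
for c in contributions:
    out += c                           # accumulate, as a finalize epilogue would
assert torch.equal(out, 3 * torch.ones(4, 8))  # correct only after zeroing
```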
@@ -344,12 +355,13 @@ class CuteDslFusedMoE(CutlassFusedMoE):
                 tile_size=tile_size,
                 output_dtype=output_dtype,
             )
-            x = torch.ops.trtllm.moe_unpermute(
+            torch.ops.trtllm.moe_unpermute_inplace(
                 permuted_input=x,
+                output=moe_output,
                 expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
                 topk_scales=token_final_scales,
             )
-        return x
+        return moe_output

     def run_moe_fp8_block_scales(
         self,
@@ -364,12 +376,12 @@ class CuteDslFusedMoE(CutlassFusedMoE):
         weight_dtype = self.w3_w1_weight.dtype

         (
-            permuted_row_to_unpermuted_row_tensor,
-            permuted_token_selected_experts_tensor,
-            permuted_data_tensor,
-            expert_first_token_offset_tensor,
-            permuted_token_final_scales_tensor,
-            unpermuted_row_to_permuted_row_tensor,
+            permuted_row_to_unpermuted_row,
+            permuted_token_selected_experts,
+            x,
+            expert_first_token_offset,
+            permuted_token_final_scales,
+            unpermuted_row_to_permuted_row,
         ) = torch.ops.trtllm.moe_permute_op(
             x,
             token_selected_experts,

@@ -388,35 +400,34 @@ class CuteDslFusedMoE(CutlassFusedMoE):
             min_latency_mode=False,
             use_fp8_block_scaling=True,
         )
-        act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
-            permuted_data_tensor)
-        h1 = cute_dsl_fp8_group_blockwise_gemm_ref(
-            a=act_input_fp8,
+        x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x)
+        x = cute_dsl_fp8_group_blockwise_gemm_ref(
+            a=x,
             b=self.w3_w1_weight.view(weight_dtype),
-            a_sf=act_input_sf,
+            a_sf=x_sf,
             b_sf=self.quant_scales[0],
-            offset_array=expert_first_token_offset_tensor,
+            offset_array=expert_first_token_offset,
         )
-        h2 = swiglu_fused_moe(h1)
-        act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(h2)
-        h3 = cute_dsl_fp8_group_blockwise_gemm_ref(
-            a=act_input_fp8,
+        x = swiglu_fused_moe(x)
+        x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x)
+        x = cute_dsl_fp8_group_blockwise_gemm_ref(
+            a=x,
             b=self.w2_weight.view(weight_dtype),
-            a_sf=act_input_sf,
+            a_sf=x_sf,
             b_sf=self.quant_scales[1],
-            offset_array=expert_first_token_offset_tensor,
+            offset_array=expert_first_token_offset,
         )
-        h4 = torch.ops.trtllm.moe_finalize_scale_op(
-            h3,
+        x = torch.ops.trtllm.moe_finalize_scale_op(
+            x,
             None,  # biases
             token_final_scales,
-            unpermuted_row_to_permuted_row_tensor,
-            permuted_row_to_unpermuted_row_tensor,
+            unpermuted_row_to_permuted_row,
+            permuted_row_to_unpermuted_row,
             token_selected_experts,
-            expert_first_token_offset_tensor,
+            expert_first_token_offset,
             enable_alltoall,
-            x.shape[0],  # num_rows
-            x.shape[1],  # (possibly padded) hidden_size
+            token_final_scales.size(0),  # num_rows
+            self.hidden_size,  # (possibly padded) hidden_size
             self.unpadded_hidden_size,  # original hidden size
             self.routing_method.top_k,
             self.expert_size_per_partition,  # num_experts_per_node

@@ -425,7 +436,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
             self.ep_size,
             self.ep_rank,
         )
-        return h4
+        return x

     def run_moe(
         self,
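For orientation, the fp8 block-scale path renamed above follows the usual permute, quantize, grouped GEMM, SwiGLU, quantize, grouped GEMM, finalize flow. A dense, unquantized PyTorch reference of that data flow (purely illustrative: no permutation, fp8 scaling, or expert parallelism, and the gate/up split order is an assumption):

```python
import torch
import torch.nn.functional as F

def moe_ref(x, w3_w1, w2, experts, scales):
    """x: [T, H]; w3_w1: [E, H, 2*I]; w2: [E, I, H];
    experts/scales: [T, top_k]."""
    out = torch.zeros_like(x)
    for k in range(experts.size(1)):
        w_in = w3_w1[experts[:, k]]            # [T, H, 2*I] per-token expert
        h = torch.einsum("th,thf->tf", x, w_in)
        gate, up = h.chunk(2, dim=-1)          # split order is assumed
        h = F.silu(gate) * up                  # SwiGLU activation
        w_out = w2[experts[:, k]]              # [T, I, H]
        out += scales[:, k:k + 1] * torch.einsum("ti,tih->th", h, w_out)
    return out
```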
@@ -433,6 +444,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
         token_selected_experts: torch.Tensor,
         token_final_scales: Optional[torch.Tensor],
         x_sf: Optional[torch.Tensor] = None,
+        moe_output: Optional[torch.Tensor] = None,
         enable_alltoall: bool = False,
     ) -> torch.Tensor:
         """

@@ -448,6 +460,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
                 this represents expert slots [num_tokens, top_k] instead.
             token_final_scales: Final scaling factors for each token
             x_sf: Input scale factors (optional, for certain quantization schemes)
+            moe_output: Pre-allocated MoE output buffer (optional, for NVLINK one-sided backend).
             enable_alltoall: Whether alltoall communication is enabled.

         Returns:

@@ -459,6 +472,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
                 token_selected_experts=token_selected_experts,
                 token_final_scales=token_final_scales,
                 x_sf=x_sf,
+                moe_output=moe_output,
                 enable_alltoall=enable_alltoall)
         elif self.has_deepseek_fp8_block_scales:
             return self.run_moe_fp8_block_scales(
@@ -389,6 +389,9 @@ class CutlassFusedMoE(MoE):
         self._weights_created = True
         self._check_configs()

+    def supports_moe_output_in_alltoall_workspace(self):
+        return True
+
     def run_moe(
         self,
         x: torch.Tensor,
@@ -354,6 +354,9 @@ class TRTLLMGenFusedMoE(MoE):

         return x, x_sf

+    def supports_moe_output_in_alltoall_workspace(self):
+        return self.has_w4a8_mxfp4_mxfp8
+
     def run_moe(
         self,
         x: torch.Tensor,
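These overrides, together with the base-class default added below, replace the earlier `isinstance` chain in `ConfigurableMoE` with a capability query: each backend owns its own answer. The shape of the pattern, with stand-in classes rather than the real ones:

```python
class Backend:                       # stands in for the MoE base class
    def supports_moe_output_in_alltoall_workspace(self) -> bool:
        return False                 # conservative default

class CutlassLike(Backend):
    def supports_moe_output_in_alltoall_workspace(self) -> bool:
        return True                  # unconditionally supported

class CuteDslLike(Backend):
    has_nvfp4 = True
    def supports_moe_output_in_alltoall_workspace(self) -> bool:
        return self.has_nvfp4        # only the NVFP4 path qualifies

# Call sites ask the object instead of testing its type:
assert CutlassLike().supports_moe_output_in_alltoall_workspace()
```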
@@ -723,6 +723,11 @@ class MoE(nn.Module):
     def expand_intermediate_size_per_partition(self):
         return self.intermediate_size_per_partition * self.intermediate_size_expand_ratio

+    def supports_moe_output_in_alltoall_workspace(self):
+        """ Supports moe_output in alltoall workspace
+        """
+        return False
+
     def reducescatter_or_allreduce(
         self,
         inputs,