[None][feat] Add alltoall to trtllm-gen MoE backend. (#8481)

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Bo Li 2025-10-21 12:42:54 +08:00 committed by GitHub
parent ab4b9966b2
commit ebb62e17d8
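
In the new path, routing and input quantization happen locally, tokens are dispatched across EP ranks with MnnvlMoe.mnnvl_moe_alltoallv, the fused MoE runner processes the received tokens, and the results are returned to their source ranks with MnnvlMoe.mnnvl_moe_alltoallv_combine. The path can also be disabled at runtime; a minimal opt-out sketch (assumed usage, not part of this commit):

import os
# Assumption: exporting this before model construction keeps the existing
# post-quant allgather path instead of the new alltoall path.
os.environ["TRTLLM_MOE_DISABLE_ALLTOALLV"] = "1"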


@@ -1,8 +1,11 @@
import os
from functools import cached_property
from typing import Dict, List, Optional, Union
import torch
from torch import nn
from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
from tensorrt_llm._utils import get_sm_version
from ...custom_ops.trtllm_gen_custom_ops import \
@@ -106,10 +109,28 @@ class TRTLLMGenFusedMoE(MoE):
assert len(
self.initial_local_expert_ids) == self.expert_size_per_partition
self.alltoall_workspace = None
self.alltoall_prepare_workspace = None
if self.enable_alltoall:
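# The MNNVL workspaces allocated here back the alltoallv prepare, dispatch and combine calls in forward_impl below.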
MnnvlMemory.initialize()
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
model_config.mapping)
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
model_config.mapping)
self._weights_created = False
if not model_config.skip_create_weights_in_init:
self.create_weights()
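# Use alltoall only when EP is wider than the experts routed per token, attention DP with TP > 1 is active, MNNVL is supported, and TRTLLM_MOE_DISABLE_ALLTOALLV is not set to "1".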
@cached_property
def enable_alltoall(self):
mapping = self.mapping
routing_experts = self.routing_method.experts_per_token
return (mapping.moe_ep_size > routing_experts
and mapping.enable_attention_dp and mapping.tp_size > 1
and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1"
and MnnvlMemory.supports_mnnvl())
def _check_configs(self):
assert self.has_deepseek_fp8_block_scales \
or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
@@ -175,6 +196,48 @@ class TRTLLMGenFusedMoE(MoE):
def post_load_weights(self):
self.quant_method.post_load_weights(self)
def _quantize_for_post_quant_comm(self, x):
"""Quantize inputs prior to post-communication (alltoall/allgather).
Returns: (x, x_sf, x_row, x_col)
"""
x_row = x.shape[0]
x_col = x.shape[1]
x_sf = None
if self.has_w4a8_mxfp4_fp8:
x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
x, self.fc31_input_dequant[0])
x_row, x_col = x.shape[0], x.shape[1]
elif self.has_nvfp4:
if isinstance(x, Fp4QuantizedTensor):
assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
x_row = x.shape[0]
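# two fp4 values are packed per uint8, so the logical column count is twice the storage width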
x_col = x.shape[1] * 2
x, x_sf = x.fp4_tensor, x.scaling_factor
else:
x_row = x.shape[0]
x_col = x.shape[1]
x, x_sf = torch.ops.trtllm.fp4_quantize(
x, self.fc31_input_scale, self.scaling_vector_size, False,
False)
elif self.has_w4a8_mxfp4_mxfp8:
x, x_sf = torch.ops.trtllm.mxfp8_quantize(
x, False, alignment=self.quant_method.weight_alignment)
x_row, x_col = x.shape[0], x.shape[1]
elif self.has_deepseek_fp8_block_scales:
# No change required before communication
pass
elif self.has_w4a16_mxfp4:
pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
x = torch.nn.functional.pad(x, (0, pad_size))
elif self.has_w4a8_nvfp4_fp8:
x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
x, 1.0 / self.fc31_input_scale)
else:
raise ValueError(
f"unsupported quantization mode for post communication: {self.quant_config.quant_mode}"
)
return x, x_sf, x_row, x_col
def forward_impl(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
@@ -202,55 +265,80 @@ class TRTLLMGenFusedMoE(MoE):
topk_group = None
routed_scaling_factor = None
run_post_quant_allgather = self.use_dp and self.parallel_size > 1
run_post_quant_allgather = (self.use_dp and self.parallel_size > 1
and not self.enable_alltoall)
post_quant_comm = run_post_quant_allgather or self.enable_alltoall
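# post_quant_comm covers both collectives: routing and input quantization run before the communication, and the MoE runners receive pre-computed expert ids/scales instead of router logits.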
x_sf = None
token_selected_experts = None
token_final_scales = None
x_row = x.shape[0]
x_col = x.shape[1]
if run_post_quant_allgather:
# apply routing
token_count = x.shape[0]
alltoall_info = None
if post_quant_comm:
token_selected_experts, token_final_scales = self.routing_method.apply(
router_logits)
token_final_scales = token_final_scales.to(torch.bfloat16)
assert token_final_scales.dtype == torch.bfloat16
assert token_selected_experts.dtype == torch.int32
# quantize inputs
if self.has_w4a8_mxfp4_fp8:
x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
x, self.fc31_input_dequant[0])
# Update x_row and x_col to the padded shape
x_row, x_col = x.shape[0], x.shape[1]
elif self.has_nvfp4:
if isinstance(x, Fp4QuantizedTensor):
assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
x_row = x.shape[0]
# note: we use uint8 to store 2 fp4 values
x_col = x.shape[1] * 2
x, x_sf = x.fp4_tensor, x.scaling_factor
else:
x_row = x.shape[0]
x_col = x.shape[1]
x, x_sf = torch.ops.trtllm.fp4_quantize(
x, self.fc31_input_scale, self.scaling_vector_size,
False, False)
elif self.has_w4a8_mxfp4_mxfp8:
x, x_sf = torch.ops.trtllm.mxfp8_quantize(
x, False, alignment=self.quant_method.weight_alignment)
# Update x_row and x_col to the padded shape
x_row, x_col = x.shape[0], x.shape[1]
elif self.has_deepseek_fp8_block_scales:
pass
elif self.has_w4a16_mxfp4:
pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
x = torch.nn.functional.pad(x, (0, pad_size))
else:
raise ValueError(
f"unsupported quantization mode with run_post_quant_allgather: {self.quant_config.quant_mode}"
)
token_selected_experts = token_selected_experts.to(torch.int32)
if token_final_scales is not None:
token_final_scales = token_final_scales.to(torch.bfloat16)
# allgather for attention DP
x, x_sf, x_row, x_col = self._quantize_for_post_quant_comm(x)
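# Alltoall dispatch: build per-rank send/recv metadata, then exchange the quantized activations, scaling factors, expert ids and scales with the EP ranks that own the selected experts.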
if self.enable_alltoall:
assert all_rank_num_tokens is not None, "all_rank_num_tokens required for alltoall"
max_num_token = max(
all_rank_num_tokens) if all_rank_num_tokens else token_count
if token_final_scales is None:
token_final_scales = torch.ones_like(token_selected_experts,
dtype=torch.float32)
else:
token_final_scales = token_final_scales.to(torch.float32)
assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
token_selected_experts,
None,
self.alltoall_prepare_workspace,
max_num_token,
self.ep_rank,
self.ep_size,
self.num_experts,
self.num_slots,
top_k,
)
if x_sf is not None:
x_sf = x_sf.view(x_row, ceil_div(x_col,
self.scaling_vector_size))
x, x_sf, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv(
[x, x_sf, token_selected_experts, token_final_scales],
alltoall_info,
self.alltoall_workspace,
self.ep_rank,
self.ep_size,
)
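# Assumption: memset_expert_ids fills expert ids for token slots not received from any rank so the MoE runner can skip them (the kernel itself is outside this diff).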
torch.ops.trtllm.memset_expert_ids(
token_selected_experts,
alltoall_info.recv_rank_count_cumsum,
max_num_token,
top_k,
self.num_slots,
self.ep_size,
)
if x_sf is not None:
x_sf = x_sf.flatten()
if token_final_scales is not None:
token_final_scales = token_final_scales.to(torch.bfloat16)
elif run_post_quant_allgather:
if x_sf is not None:
x_sf = x_sf.view(x_row, ceil_div(x_col,
self.scaling_vector_size))
@@ -265,6 +353,9 @@ class TRTLLMGenFusedMoE(MoE):
if x_sf is not None:
x_sf = x_sf.flatten()
router_logits_arg = router_logits if not post_quant_comm else None
routing_bias_arg = routing_bias if not post_quant_comm else None
# TODO: since routing kernel is integrated into moe_runner for fp8,
# here we just route the I/Os for moe_runner
if self.has_deepseek_fp8_block_scales:
@@ -272,8 +363,8 @@ class TRTLLMGenFusedMoE(MoE):
x_val, x_scale = torch.ops.trtllm.fp8_quantize_1x128(x)
final_hidden_states = torch.ops.trtllm.fp8_block_scale_moe_runner(
router_logits if not run_post_quant_allgather else None,
routing_bias if not run_post_quant_allgather else None,
router_logits_arg,
routing_bias_arg,
x_val,
x_scale,
self.w3_w1_weight,
@@ -297,7 +388,7 @@ class TRTLLMGenFusedMoE(MoE):
scale_factor_use_ue8m0 = False
is_scale_factor_swizzled = False # use linear layout here
if not run_post_quant_allgather:
if not post_quant_comm:
hidden_states_fp4, hidden_states_scale_linear_fp4 = (
torch.ops.trtllm.fp4_quantize(
x,
@@ -310,8 +401,8 @@ class TRTLLMGenFusedMoE(MoE):
hidden_states_fp4, hidden_states_scale_linear_fp4 = x, x_sf
outputs = torch.ops.trtllm.fp4_block_scale_moe_runner(
router_logits if not run_post_quant_allgather else None,
routing_bias if not run_post_quant_allgather else None,
router_logits_arg,
routing_bias_arg,
hidden_states_fp4,
hidden_states_scale_linear_fp4.view(torch.float8_e4m3fn),
self.w3_w1_weight,
@@ -343,7 +434,7 @@ class TRTLLMGenFusedMoE(MoE):
final_hidden_states = outputs[0]
elif self.has_w4a16_mxfp4:
assert x.dtype == torch.bfloat16
if not run_post_quant_allgather:
if not post_quant_comm:
pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
x = torch.nn.functional.pad(x, (0, pad_size))
else:
@@ -352,8 +443,8 @@ class TRTLLMGenFusedMoE(MoE):
intermediate_size_per_partition_padded = self.w3_w1_weight.shape[
-2] // 2
final_hidden_states = torch.ops.trtllm.bf16_mxe2m1_block_scale_moe_runner(
router_logits if not run_post_quant_allgather else None,
routing_bias if not run_post_quant_allgather else None,
router_logits_arg,
routing_bias_arg,
x,
self.w3_w1_weight,
self.w3_w1_weight_scale,
@@ -383,15 +474,15 @@ class TRTLLMGenFusedMoE(MoE):
hidden_size].contiguous()
elif self.has_w4a8_nvfp4_fp8:
if not run_post_quant_allgather:
if not post_quant_comm:
hidden_states_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
x, 1.0 / self.fc31_input_scale)
else:
hidden_states_fp8 = x
outputs = torch.ops.trtllm.fp8_fp4_block_scale_moe_runner(
router_logits,
routing_bias,
router_logits_arg,
routing_bias_arg,
hidden_states_fp8,
self.w3_w1_weight,
self.w3_w1_weight_scale.view(torch.float8_e4m3fn),
@@ -423,7 +514,7 @@ class TRTLLMGenFusedMoE(MoE):
final_hidden_states = outputs[0]
elif self.has_w4a8_mxfp4_fp8:
pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
if not run_post_quant_allgather:
if not post_quant_comm:
x = torch.nn.functional.pad(x, (0, pad_size))
x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
x, self.fc31_input_gate_dequant[0])
@@ -433,8 +524,8 @@ class TRTLLMGenFusedMoE(MoE):
-2] // 2
final_hidden_states = torch.ops.trtllm.e4m3_mxe2m1_block_scale_moe_runner(
router_logits if not run_post_quant_allgather else None,
routing_bias if not run_post_quant_allgather else None,
router_logits_arg,
routing_bias_arg,
x,
self.w3_w1_weight,
self.w3_w1_weight_scale,
@@ -466,7 +557,7 @@ class TRTLLMGenFusedMoE(MoE):
final_hidden_states = final_hidden_states[:, :self.
hidden_size].contiguous()
elif self.has_w4a8_mxfp4_mxfp8:
if not run_post_quant_allgather:
if not post_quant_comm:
# TRTLLM-Gen uses linear SF layout for the mxfp8 input.
mxfp8_x, sf = torch.ops.trtllm.mxfp8_quantize(
x, False, alignment=self.quant_method.weight_alignment)
@@ -477,8 +568,8 @@ class TRTLLMGenFusedMoE(MoE):
-2] // 2
final_hidden_states = torch.ops.trtllm.mxe4m3_mxe2m1_block_scale_moe_runner(
router_logits if not run_post_quant_allgather else None,
routing_bias if not run_post_quant_allgather else None,
router_logits_arg,
routing_bias_arg,
mxfp8_x,
sf,
self.w3_w1_weight,
@@ -511,6 +602,18 @@ class TRTLLMGenFusedMoE(MoE):
"TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_mxfp8 and w4a8_mxfp4_fp8 dtypes."
)
# Combine results if using alltoall
if self.enable_alltoall and alltoall_info is not None:
final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
final_hidden_states,
alltoall_info,
self.alltoall_workspace,
ep_rank=self.ep_rank,
ep_size=self.ep_size,
top_k=top_k,
token_count=token_count,
)
final_hidden_states = self.reducescatter_or_allreduce(
final_hidden_states,
all_rank_num_tokens=all_rank_num_tokens,