Clean up allreduce op in Deepseek V3 model. (#3829)

* Replace deepseek_allreduce op with the new unified allreduce op and moe_allreduce op.
* Minor revision of moe_allreduce op argument names.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Yukun He 2025-05-01 07:56:36 +08:00 committed by GitHub
parent b40f351b7a
commit 9cc5922a0b
7 changed files with 383 additions and 377 deletions
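The sketch below is illustrative only and not part of this diff: it outlines how a decoder layer uses the two ops after this change. The helper names (build_allreduce_modules, pre_moe_fusion, post_moe_fusion_min_latency) are hypothetical, and arguments such as mapping, hidden_states, and shared_output stand in for the modeling-code variables shown further down.

from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
                                             AllReduceParams, MoEAllReduce)

def build_allreduce_modules(mapping):
    # One unified AllReduce for the fused AR + residual + RMSNorm paths, and one
    # MoEAllReduce for the fused MoE reduction + AR + RMSNorm path.
    return AllReduce(mapping), MoEAllReduce(mapping)

def pre_moe_fusion(allreduce, hidden_states, residual, norm_weight, eps):
    # Replaces the old DeepseekAllReduce call that used RESIDUAL_RMS_NORM.
    return allreduce(
        hidden_states,
        all_reduce_params=AllReduceParams(
            fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
            residual=residual,
            norm_weight=norm_weight,
            eps=eps,
        ))

def post_moe_fusion_min_latency(moe_allreduce, residual, norm_weight,
                                num_experts_per_node, expert_scales,
                                active_expert_outputs, shared_output, eps):
    # Replaces the old DeepseekAllReduce call that used MOE_ALLREDUCE_RESIDUAL_RMS_NORM.
    return moe_allreduce(
        residual,
        norm_weight,
        device_num_experts=num_experts_per_node,
        scale_input=expert_scales,
        active_experts_token_input=active_expert_outputs,
        token_input=shared_output,
        eps=eps,
    )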

View File

@@ -273,12 +273,11 @@ private:
torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
{
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
int size = input.numel();
int hidden_size = input.size(-1);
torch::Tensor output = torch::empty_like(input);
if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM)
{
torch::Tensor norm_out = torch::empty_like(input);
@@ -815,14 +814,14 @@ std::vector<torch::Tensor> allreduce(torch::Tensor input, torch::optional<torch:
// residual [m, hidden_dim]
// norm_weight [hidden_dim]
// moe_reduction_device_num_experts [1]
// moe_reduction_scale_input [global_num_experts, m]
// moe_reduction_active_experts_token_input [device_num_experts, m, hidden_dim]
// moe_reduction_token_input [m, hidden_dim]
std::vector<torch::Tensor> moe_allreduce(torch::Tensor residual, torch::Tensor norm_weight,
torch::Tensor moe_reduction_device_num_experts, torch::Tensor moe_reduction_scale_input,
torch::Tensor moe_reduction_active_experts_token_input, torch::Tensor moe_reduction_token_input,
torch::optional<torch::Tensor> workspace, int64_t const rank, int64_t const nranks, double const eps)
// device_num_experts [1]
// scale_input [global_num_experts, m]
// active_experts_token_input [device_num_experts, m, hidden_dim]
// token_input [m, hidden_dim]
std::vector<torch::Tensor> moe_allreduce(torch::Tensor const& residual, torch::Tensor const& norm_weight,
torch::Tensor const& device_num_experts, torch::Tensor const& scale_input,
torch::Tensor const& active_experts_token_input, torch::Tensor const& token_input, torch::Tensor workspace,
int64_t const rank, int64_t const nranks, double const eps)
{
auto allreduce_fusion_params = tensorrt_llm::kernels::ar_fusion::moe::MoeReductionAllReduceFusionParams();
@@ -833,14 +832,13 @@ std::vector<torch::Tensor> moe_allreduce(torch::Tensor residual, torch::Tensor n
allreduce_fusion_params.nranks = static_cast<int>(nranks);
allreduce_fusion_params.rank = static_cast<int>(rank);
allreduce_fusion_params.dtype
= tensorrt_llm::runtime::TorchUtils::dataType(moe_reduction_token_input.scalar_type());
allreduce_fusion_params.dtype = tensorrt_llm::runtime::TorchUtils::dataType(token_input.scalar_type());
// size: num_token * hidden_dim
allreduce_fusion_params.size = static_cast<int>(moe_reduction_token_input.numel());
allreduce_fusion_params.hidden_dim = static_cast<int>(moe_reduction_active_experts_token_input.size(-1));
allreduce_fusion_params.size = static_cast<int>(token_input.numel());
allreduce_fusion_params.hidden_dim = static_cast<int>(active_experts_token_input.size(-1));
// workspace: AR scratch space
allreduce_fusion_params.workspace = reinterpret_cast<void**>(workspace.value().mutable_data_ptr());
allreduce_fusion_params.workspace = reinterpret_cast<void**>(workspace.mutable_data_ptr());
allreduce_fusion_params.rms_gamma = norm_weight.data_ptr();
allreduce_fusion_params.rms_eps = static_cast<float>(eps);
@@ -850,15 +848,13 @@ std::vector<torch::Tensor> moe_allreduce(torch::Tensor residual, torch::Tensor n
// MOE Reduction specific params
allreduce_fusion_params.allreduce_in = nullptr; // for safety, set nullptr
allreduce_fusion_params.moe_reduction_device_num_experts
= static_cast<int*>(moe_reduction_device_num_experts.data_ptr());
allreduce_fusion_params.moe_reduction_scale_input = static_cast<float*>(moe_reduction_scale_input.data_ptr());
allreduce_fusion_params.moe_reduction_active_experts_token_input
= moe_reduction_active_experts_token_input.data_ptr();
allreduce_fusion_params.moe_reduction_token_input = moe_reduction_token_input.data_ptr();
allreduce_fusion_params.moe_reduction_device_num_experts = static_cast<int*>(device_num_experts.data_ptr());
allreduce_fusion_params.moe_reduction_scale_input = static_cast<float*>(scale_input.data_ptr());
allreduce_fusion_params.moe_reduction_active_experts_token_input = active_experts_token_input.data_ptr();
allreduce_fusion_params.moe_reduction_token_input = token_input.data_ptr();
// output tensors
torch::Tensor norm_out = torch::empty_like(moe_reduction_token_input);
torch::Tensor norm_out = torch::empty_like(token_input);
torch::Tensor residual_out = torch::empty_like(residual);
allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
@@ -874,15 +870,29 @@ std::vector<torch::Tensor> moe_allreduce(torch::Tensor residual, torch::Tensor n
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"allreduce(Tensor input, Tensor? residual, Tensor? norm_weight, Tensor? scale, Tensor? bias, Tensor? "
"workspace, int[] group, int "
"strategy, int op, float eps) -> Tensor[]");
"allreduce("
"Tensor input,"
"Tensor? residual,"
"Tensor? norm_weight,"
"Tensor? scale,"
"Tensor? bias,"
"Tensor? workspace,"
"int[] group,"
"int strategy,"
"int op,"
"float eps) -> Tensor[]");
m.def(
"moe_allreduce(Tensor residual, Tensor norm_weight, Tensor "
"moe_reduction_device_num_experts, "
"Tensor moe_reduction_scale_input, Tensor moe_reduction_active_experts_token_input, Tensor "
"moe_reduction_token_input, Tensor? workspace, "
"int rank, int nranks, float eps) -> Tensor[]");
"moe_allreduce("
"Tensor residual,"
"Tensor norm_weight,"
"Tensor device_num_experts,"
"Tensor scale_input,"
"Tensor active_experts_token_input,"
"Tensor token_input,"
"Tensor workspace,"
"int rank,"
"int nranks,"
"float eps) -> Tensor[]");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
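A direct call to the renamed op, mirroring the schema above, might look like the hedged sketch below. In practice the MoEAllReduce module (see the Python changes that follow) issues this call and supplies the allreduce workspace; the wrapper name and shape annotations are illustrative assumptions.

import torch

def call_moe_allreduce(residual, norm_weight, device_num_experts, scale_input,
                       active_experts_token_input, token_input, workspace,
                       tp_rank, tp_size, eps=1e-5):
    # device_num_experts is a length-1 int32 tensor; scale_input is float32.
    norm_out, residual_out = torch.ops.trtllm.moe_allreduce(
        residual=residual,                        # [m, hidden_dim]
        norm_weight=norm_weight,                  # [hidden_dim]
        device_num_experts=device_num_experts,    # [1]
        scale_input=scale_input,                  # per-expert, per-token scale
        active_experts_token_input=active_experts_token_input,  # [device_num_experts, m, hidden_dim]
        token_input=token_input,                  # [m, hidden_dim]
        workspace=workspace,                      # AR scratch space
        rank=tp_rank,
        nranks=tp_size,
        eps=eps,
    )
    return norm_out, residual_out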

View File

@@ -53,10 +53,10 @@ def _register_fake():
return [torch.empty_like(input)]
@torch.library.register_fake("trtllm::moe_allreduce")
def _(residual, norm_weight, moe_reduction_device_num_experts,
moe_reduction_scale_input, moe_reduction_active_experts_token_input,
moe_reduction_token_input, workspace, rank, nranks, eps):
norm_out = torch.empty_like(moe_reduction_token_input)
def _(residual, norm_weight, device_num_experts, scale_input,
active_experts_token_input, token_input, workspace, rank, nranks,
eps):
norm_out = torch.empty_like(token_input)
residual_out = torch.empty_like(residual)
return [norm_out, residual_out]

View File

@@ -1,11 +1,10 @@
from .communicator import Distributed, MPIDist, PPComm, TorchDist
from .ops import (AllReduce, AllReduceFusionOp, AllReduceParams,
AllReduceStrategy, DeepseekAllReduce, allgather,
AllReduceStrategy, DeepseekAllReduce, MoEAllReduce, allgather,
reducescatter, userbuffers_allreduce_finalize)
__all__ = [
"allgather",
"allreduce",
"reducescatter",
"userbuffers_allreduce_finalize",
"AllReduce",
@@ -13,6 +12,7 @@ __all__ = [
"AllReduceFusionOp",
"AllReduceStrategy",
"DeepseekAllReduce",
"MoEAllReduce",
"TorchDist",
"PPComm",
"MPIDist",

View File

@@ -132,6 +132,10 @@ class AllReduce(nn.Module):
- RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4
- AUTO: AUTO chooses between NCCL and MIN_LATENCY mode based on a heuristic policy.
Note:
For reference implementations of each pattern, see the following unit test:
https://github.com/NVIDIA/TensorRT-LLM/blob/main/tests/unittest/_torch/multi_gpu/test_allreduce.py
"""
self.mapping = mapping
@@ -196,6 +200,67 @@ class AllReduce(nn.Module):
return output if len(output) > 1 else output[0]
class MoEAllReduce(nn.Module):
def __init__(self, mapping: Mapping):
"""
MoEAllReduce is a module that fuses the MoE expert reduction with the
subsequent allreduce and RMS norm.
Args:
mapping (Mapping): The parallel mapping config.
Notes:
Supported pattern: MoE Reduction + Add + AllReduce + Add + RMSNorm; see this torch reference implementation:
expert_reduction = torch.sum(active_experts_token_input *
scale.unsqueeze(-1),
dim=0)
output_add = expert_reduction + shared_expert_output
output_residual = output_add + residual
output_hidden_states = rms_norm(output_residual, norm_weight, eps)
"""
super().__init__()
self.mapping = mapping
self.workspace = get_allreduce_workspace(self.mapping)
def forward(
self,
residual: torch.Tensor,
norm_weight: torch.Tensor,
device_num_experts: torch.Tensor,
scale_input: torch.Tensor,
active_experts_token_input: torch.Tensor,
token_input: torch.Tensor,
eps: float,
) -> torch.Tensor:
"""
Args:
residual: residual tensor
norm_weight: RMS norm weight
device_num_experts: number of experts per device
scale_input: per-expert, per-token scale (experts-to-token score)
active_experts_token_input: per-expert, per-token input of the active experts
token_input: per-token input (the shared expert output)
eps: epsilon for RMSNorm
Output:
hidden_states: hidden_states of the model
residual: residual tensor
"""
return torch.ops.trtllm.moe_allreduce(
residual=residual,
norm_weight=norm_weight,
device_num_experts=device_num_experts,
scale_input=scale_input,
active_experts_token_input=active_experts_token_input,
token_input=token_input,
workspace=self.workspace,
rank=self.mapping.tp_rank,
nranks=self.mapping.tp_size,
eps=eps,
)
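Expanding the docstring's reference pattern into a runnable, single-device sketch may help; the function and its names are illustrative only, and the fused kernel additionally performs the allreduce across TP ranks between the two additions.

import torch

def moe_allreduce_reference(active_experts_token_input, scale, shared_expert_output,
                            residual, norm_weight, eps):
    # MoE reduction: weight each active expert's per-token output by its scale
    # and sum over the expert dimension.
    expert_reduction = torch.sum(
        active_experts_token_input * scale.unsqueeze(-1), dim=0)
    # Add the shared-expert output, then the residual.
    output_add = expert_reduction + shared_expert_output
    output_residual = output_add + residual
    # RMSNorm over the hidden dimension, computed in float32 for stability.
    x = output_residual.to(torch.float32)
    normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    hidden_states = (normed * norm_weight.to(torch.float32)).to(residual.dtype)
    return hidden_states, output_residual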
class DeepseekAllReduce(nn.Module):
def __init__(self, mapping: Mapping):

View File

@ -45,7 +45,7 @@ from tensorrt_llm.llmapi.utils import enable_llm_debug
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
DeepseekAllReduce, allgather)
MoEAllReduce, allgather)
from ..model_config import ModelConfig
from ..models.modeling_utils import MissingLayer, ModelConfig, support_pp
from ..modules.attention import MLA
@@ -392,7 +392,7 @@ class Deepseekv3MoE(nn.Module):
overridden_tp_size=shared_tp_size,
reduce_output=False)
self.all_reduce = AllReduce(self.mapping)
self.allreduce = AllReduce(self.mapping)
self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
self.event_dict = {
key: torch.cuda.Event()
@@ -516,7 +516,7 @@ class Deepseekv3MoE(nn.Module):
), f'unmatched tensor shape'
final_hidden_states = shared_output + routed_output
if not self.use_dp and self.mapping.tp_size > 1:
final_hidden_states = self.all_reduce(
final_hidden_states = self.allreduce(
final_hidden_states,
all_reduce_params=final_all_reduce_params)
@@ -608,17 +608,10 @@ class DeepseekV3DecoderLayer(DecoderLayer):
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.layer_idx = layer_idx
self.all_reduce = AllReduce(self.mapping)
self.allreduce = AllReduce(self.mapping)
self.moe_allreduce = MoEAllReduce(self.mapping)
self.next_layer_layernorm: RMSNorm = None
self.deepseek_allreduce_disabled = os.environ.get(
"TRTLLM_DEEPSEEK_ALLREDUCE_FUSION_DISABLED", "0") == "1"
if mapping.is_multi_node():
self.deepseek_allreduce_disabled = True
if not self.deepseek_allreduce_disabled:
self.deepseek_allreduce = DeepseekAllReduce(self.mapping)
def _compute_mlp_tp_size(self, intermediate_size: int,
block_size: int) -> int:
"""
@@ -675,19 +668,25 @@ class DeepseekV3DecoderLayer(DecoderLayer):
**kwargs,
)
# deepseek allreduce kernel is better when m < 512, two shot(128~512) has acc bug, waive
using_prev_fusion = self.deepseek_allreduce_disabled or hidden_states.size(
0) > 128
min_latency_mode = self._enable_latency_mode(
hidden_states.size(0)) and not using_prev_fusion
min_latency_mode = self._enable_latency_mode(hidden_states.size(0))
hidden_states_fp4 = None
if self.fusion_config.PRE_MOE_FUSION:
# Custom AR Fusion for DeepseekV3
if using_prev_fusion:
# Custom AR Fusion for DeepseekV3
hidden_states, residual = self.all_reduce(
if min_latency_mode:
hidden_states, hidden_states_act, hidden_states_sf, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.
RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4,
residual=residual,
norm_weight=self.post_attention_layernorm.weight,
scale=self.mlp.experts.fc31_input_scale,
eps=self.post_attention_layernorm.variance_epsilon,
))
hidden_states_fp4 = Fp4QuantizedTensor(hidden_states_act,
hidden_states_sf)
else:
hidden_states, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
@@ -695,52 +694,17 @@ class DeepseekV3DecoderLayer(DecoderLayer):
norm_weight=self.post_attention_layernorm.weight,
eps=self.post_attention_layernorm.variance_epsilon,
))
else:
if min_latency_mode:
hidden_states, hidden_states_act, hidden_states_sf, residual = self.deepseek_allreduce(
hidden_states,
[
residual, self.post_attention_layernorm.weight,
self.mlp.experts.fc31_input_scale
],
self.post_attention_layernorm.variance_epsilon,
AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4,
)
hidden_states_fp4 = Fp4QuantizedTensor(
hidden_states_act, hidden_states_sf)
else:
hidden_states, residual = self.deepseek_allreduce(
hidden_states,
[residual, self.post_attention_layernorm.weight],
self.post_attention_layernorm.variance_epsilon,
AllReduceFusionOp.RESIDUAL_RMS_NORM,
)
elif self.fusion_config.PRE_MLP_FUSION:
# Custom AR Fusion for DeepseekV3 with quant_fp4
if using_prev_fusion:
hidden_states, residual = self.all_reduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
residual=residual,
norm_weight=self.post_attention_layernorm.weight,
eps=self.post_attention_layernorm.variance_epsilon,
))
act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
hidden_states, self.mlp.gate_up_proj.input_scale,
self.mlp.gate_up_proj.scaling_vector_size, False)
else:
act_fp4, act_sf, residual = self.deepseek_allreduce(
hidden_states,
[
residual, self.post_attention_layernorm.weight,
self.mlp.gate_up_proj.input_scale
],
self.post_attention_layernorm.variance_epsilon,
AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4,
)
act_fp4, act_sf, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4,
residual=residual,
norm_weight=self.post_attention_layernorm.weight,
scale=self.mlp.gate_up_proj.input_scale,
eps=self.post_attention_layernorm.variance_epsilon,
))
hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
else:
# No fusion
hidden_states, residual = self.post_attention_layernorm(
@@ -769,62 +733,39 @@ class DeepseekV3DecoderLayer(DecoderLayer):
)
if self.fusion_config.POST_MOE_FUSION:
if using_prev_fusion:
hidden_states, residual = self.all_reduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
residual=residual,
norm_weight=self.next_layer_layernorm.weight,
eps=self.next_layer_layernorm.variance_epsilon,
))
else:
if min_latency_mode:
shared_output = hidden_states[0]
hidden_states_activated_experts = hidden_states[1]
num_activated_experts_per_node = hidden_states[2]
experts_to_token_score = hidden_states[3]
activated_expert_global_ids = hidden_states[4]
if min_latency_mode:
shared_output = hidden_states[0]
hidden_states_activated_experts = hidden_states[1]
num_activated_experts_per_node = hidden_states[2]
experts_to_token_score = hidden_states[3]
hidden_states, residual = self.deepseek_allreduce(
hidden_states_activated_experts, # not used
[
residual, self.next_layer_layernorm.weight,
num_activated_experts_per_node,
experts_to_token_score,
hidden_states_activated_experts, shared_output,
activated_expert_global_ids
],
self.next_layer_layernorm.variance_epsilon,
AllReduceFusionOp.MOE_ALLREDUCE_RESIDUAL_RMS_NORM,
)
else:
hidden_states, residual = self.deepseek_allreduce(
hidden_states,
[residual, self.next_layer_layernorm.weight],
self.next_layer_layernorm.variance_epsilon,
AllReduceFusionOp.RESIDUAL_RMS_NORM,
)
elif self.fusion_config.POST_MLP_FUSION:
if using_prev_fusion:
# Custom AR Fusion for DeepseekV3
hidden_states, residual = self.all_reduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
residual=residual,
norm_weight=self.next_layer_layernorm.weight,
eps=self.next_layer_layernorm.variance_epsilon,
))
else:
hidden_states, residual = self.deepseek_allreduce(
hidden_states,
[residual, self.next_layer_layernorm.weight],
self.next_layer_layernorm.variance_epsilon,
AllReduceFusionOp.RESIDUAL_RMS_NORM,
hidden_states, residual = self.moe_allreduce(
residual,
self.next_layer_layernorm.weight,
device_num_experts=num_activated_experts_per_node,
scale_input=experts_to_token_score,
active_experts_token_input=hidden_states_activated_experts,
token_input=shared_output,
eps=self.next_layer_layernorm.variance_epsilon,
)
else:
hidden_states, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
residual=residual,
norm_weight=self.next_layer_layernorm.weight,
eps=self.next_layer_layernorm.variance_epsilon,
))
elif self.fusion_config.POST_MLP_FUSION:
hidden_states, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
residual=residual,
norm_weight=self.next_layer_layernorm.weight,
eps=self.next_layer_layernorm.variance_epsilon,
))
else:
if self.next_layer_layernorm is not None:
hidden_states, residual = self.next_layer_layernorm(
@@ -878,9 +819,6 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
) -> Tuple[torch.Tensor, torch.Tensor]:
# deepseek allreduce kernel is better when m < 512
using_prev_fusion = self.deepseek_allreduce_disabled or hidden_states.size(
0) >= 512
inputs_embeds = self.enorm(embed_tokens(input_ids))
hidden_states = self.hnorm(hidden_states)
hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1)
@@ -902,24 +840,14 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
# MTP Layer Must have sparse MOE
if self.fusion_config.PRE_MOE_FUSION:
# Custom AR Fusion for DeepseekV3
if using_prev_fusion:
# Custom AR Fusion for DeepseekV3
hidden_states, residual = self.all_reduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
residual=residual,
norm_weight=self.post_attention_layernorm.weight,
eps=self.post_attention_layernorm.variance_epsilon,
))
else:
hidden_states, residual = self.deepseek_allreduce(
hidden_states,
[residual, self.post_attention_layernorm.weight],
self.post_attention_layernorm.variance_epsilon,
AllReduceFusionOp.RESIDUAL_RMS_NORM,
)
hidden_states, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
residual=residual,
norm_weight=self.post_attention_layernorm.weight,
eps=self.post_attention_layernorm.variance_epsilon,
))
else:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
@@ -933,22 +861,14 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
)
if self.fusion_config.POST_MOE_FUSION:
if using_prev_fusion:
hidden_states, residual = self.all_reduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
residual=residual,
norm_weight=self.shared_head.norm.weight,
eps=self.shared_head.norm.variance_epsilon,
))
else:
hidden_states, residual = self.deepseek_allreduce(
hidden_states,
[residual, self.shared_head.norm.weight],
self.shared_head.norm.variance_epsilon,
AllReduceFusionOp.RESIDUAL_RMS_NORM,
)
hidden_states, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
residual=residual,
norm_weight=self.shared_head.norm.weight,
eps=self.shared_head.norm.variance_epsilon,
))
else:
hidden_states, _ = self.shared_head.norm(hidden_states, residual)

View File

@@ -26,7 +26,7 @@ from utils.util import skip_pre_blackwell
import tensorrt_llm
from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
AllReduceParams)
AllReduceParams, MoEAllReduce)
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm.mapping import Mapping
@@ -72,6 +72,21 @@ def run_single_rank(tensor_parallel_size, single_rank_forward_func, input,
return True
def run_moe_single_rank(tensor_parallel_size, single_rank_forward_func,
token_input, residual, active_experts_token_input,
scale, l0_weight):
rank = tensorrt_llm.mpi_rank()
torch.cuda.set_device(rank)
try:
single_rank_forward_func(token_input, residual,
active_experts_token_input, scale,
tensor_parallel_size, rank, l0_weight)
except Exception:
traceback.print_exc()
raise
return True
@torch.inference_mode()
def run_allreduce_op(x: torch.Tensor, residual: torch.Tensor, hidden_size: int,
dtype: torch.dtype, tensor_parallel_size: int,
@@ -238,28 +253,29 @@ def run_allreduce_op(x: torch.Tensor, residual: torch.Tensor, hidden_size: int,
assert mismatch_percentage < 0.01, f"Large mismatched elements encountered"
@skip_pre_blackwell
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Requires at least 2 GPUs for this test")
@pytest.mark.parametrize("seq_len", [16, 256], ids=lambda x: f"seqlen:{x}")
@pytest.mark.parametrize("hidden_size", [128, 7168],
ids=lambda x: f"hidden:{x}")
@pytest.mark.parametrize("fusion_op", [
AllReduceFusionOp.NONE,
AllReduceFusionOp.RESIDUAL_RMS_NORM,
AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8,
AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8,
AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4,
AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4,
],
ids=[
"none",
"residual_rms_norm",
"residual_rms_norm_quant_fp8",
"residual_rms_norm_out_quant_fp8",
"residual_rms_norm_quant_nvfp4",
"residual_rms_norm_out_quant_nvfp4",
])
@pytest.mark.parametrize(
"fusion_op",
[
pytest.param(AllReduceFusionOp.NONE, id="none"),
pytest.param(AllReduceFusionOp.RESIDUAL_RMS_NORM,
id="residual_rms_norm"),
pytest.param(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8,
id="residual_rms_norm_quant_fp8"),
pytest.param(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8,
id="residual_rms_norm_out_quant_fp8"),
pytest.param(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4,
id="residual_rms_norm_quant_nvfp4",
marks=skip_pre_blackwell),
pytest.param(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4,
id="residual_rms_norm_out_quant_nvfp4",
marks=skip_pre_blackwell),
],
)
def test_allreduce_fusion_patterns(seq_len, hidden_size, fusion_op):
torch.manual_seed(0)
dtype = torch.bfloat16
@@ -276,3 +292,163 @@ def test_allreduce_fusion_patterns(seq_len, hidden_size, fusion_op):
)
for r in results:
assert r is True
@torch.inference_mode()
def run_moe_allreduce_op(token_input: torch.Tensor, residual: torch.Tensor,
active_experts_token_input: torch.Tensor,
scale: torch.Tensor, tensor_parallel_size: int,
tensor_parallel_rank: int, l0_weight: torch.Tensor):
torch.manual_seed(42)
# * token_input:
# [num_token, 7168]
# different values on each device
# * active_experts_token_input
# [num_global_exp, num_token, 7168]
# needs to be sliced to [num_device_exp, num_token, 7168] before use
# * scale
# [num_global_exp, num_token]
# per-expert, per-token scale
# needs to be sliced to [num_device_exp, num_token] before use
# different values on each device
token_input = token_input.cuda()
residual = residual.cuda()
active_experts_token_input = active_experts_token_input.cuda()
scale = scale.cuda()
dtype = token_input.dtype
num_global_experts = scale.size(0)
num_device_experts = num_global_experts // tensor_parallel_size
tensor_num_device_experts = torch.tensor(num_device_experts,
dtype=torch.int32,
device="cuda")
# num_token = token_input.shape[0]
hidden_size = token_input.shape[1]
# Setup parameters
eps = 1e-5
norm_weight = torch.randn((hidden_size, ), dtype=dtype, device="cuda")
# Initialize MoEAllreduce
moe_allreduce = MoEAllReduce(mapping=Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
)).cuda()
# Initialize RMSNorm
norm = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype).cuda()
norm.weight.data.copy_(norm_weight)
l0 = Linear(
in_features=hidden_size,
out_features=hidden_size,
bias=False,
dtype=dtype,
mapping=Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
),
tensor_parallel_mode=TensorParallelMode.ROW,
).cuda()
l0.load_weights([dict(weight=l0_weight)])
token_input_chunked = torch.chunk(token_input.clone(),
tensor_parallel_size,
dim=-1)
fc2_output = l0(
token_input_chunked[tensor_parallel_rank],
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
residual=residual,
norm_weight=norm_weight,
eps=eps,
enable_allreduce=False,
),
)
# Define fusion operation
# slice [num_global_exp, num_token, 7168] -> [num_device_exp, num_token, 7168]
active_experts_token_input_parallel = torch.chunk(
active_experts_token_input.clone(), tensor_parallel_size, dim=0)
active_experts_token_equalized = active_experts_token_input_parallel[
tensor_parallel_rank]
# slice [num_global_exp, num_token] -> [num_device_exp, num_token]
scale_parallel = torch.chunk(scale.clone(), tensor_parallel_size, dim=0)
scale_equalized = scale_parallel[tensor_parallel_rank]
# Run with fusion
output_hidden_states, output_residual = moe_allreduce(
residual,
norm_weight,
tensor_num_device_experts,
scale_equalized,
active_experts_token_equalized,
fc2_output,
eps,
)
torch_l0 = torch.nn.Linear(in_features=hidden_size,
out_features=hidden_size,
bias=False,
dtype=dtype)
torch_l0.weight.data.copy_(l0_weight)
torch_l0.cuda()
torch_linear_output = torch_l0(token_input)
# Verify with torch reference implementation
expert_reduction = torch.sum(active_experts_token_input *
scale.unsqueeze(-1),
dim=0)
torch_before_residual = expert_reduction + torch_linear_output
torch_residual = torch_before_residual + residual
torch_residual = torch_residual.to(torch.float32)
torch_output_hidden_states = rms_norm(torch_residual, norm_weight,
eps).to(dtype)
# Verify results are close to reference
torch.testing.assert_close(
output_hidden_states,
torch_output_hidden_states,
rtol=0.2,
atol=0.2,
)
return True
@torch.inference_mode()
def test_moe_allreduce_patterns():
torch.manual_seed(42)
seq_len = 16
hidden_size = 7168
dtype = torch.bfloat16
tensor_parallel_size = 2
num_global_experts = 4
# [num_token, 7168]
token_input = torch.randn((seq_len, hidden_size), dtype=dtype)
# [num_global_exp, num_token, 7168]
active_experts_token_input = torch.randn(
(num_global_experts, seq_len, hidden_size), dtype=dtype, device="cuda")
# [num_global_exp, num_token]
scale = torch.randn((num_global_experts, seq_len),
dtype=torch.float32,
device="cuda")
# [num_token, 7168]
residual = torch.randn_like(token_input)
l0_weight = torch.randn((hidden_size, hidden_size), dtype=dtype)
with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
results = executor.map(
run_moe_single_rank,
*zip(*[(tensor_parallel_size, run_moe_allreduce_op, token_input,
residual, active_experts_token_input, scale, l0_weight)] *
tensor_parallel_size),
)
for r in results:
assert r is True
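The reference check in run_moe_allreduce_op relies on an rms_norm helper defined elsewhere in this test file; a minimal sketch of the assumed form (not the exact code, and relying on the torch import already present in the file) is:

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Root-mean-square normalization over the last dimension, then scale.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight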

View File

@@ -19,15 +19,13 @@ import traceback
import cloudpickle
import pytest
import torch
import torch.nn as nn
from mpi4py import MPI
from mpi4py.futures import MPIPoolExecutor
from utils.util import skip_pre_blackwell
import tensorrt_llm
from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
AllReduceParams, DeepseekAllReduce)
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
DeepseekAllReduce)
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm.mapping import Mapping
@@ -227,166 +225,3 @@ def test_row_linear_residual_norm_fusion(seq_len, hidden_size, fusion_op):
)
for r in results:
assert r is True
@torch.inference_mode()
def moe_residual_norm_fusion_forward(
token_input: torch.Tensor, residual: torch.Tensor,
active_experts_token_input: torch.Tensor, scale: torch.Tensor,
tensor_parallel_size: int, tensor_parallel_rank: int,
l0_weight: torch.Tensor):
torch.manual_seed(42)
# * token_input:
# [num_token, 7168]
# different val for different device
# * active_experts_token_input
# [num_global_exp, num_token, 7168]
# need to slice to [num_device_exp, num_token, 7168] before use
# * scale
# [num_global_exp, num_token]
# per expert per token scale
# need to slice to [num_device_exp, num_token, 7168] before use
# different value for each device
token_input = token_input.cuda()
residual = residual.cuda()
active_experts_token_input = active_experts_token_input.cuda()
scale = scale.cuda()
dtype = token_input.dtype
num_global_experts = scale.size(0)
num_device_experts = num_global_experts // tensor_parallel_size
tensor_num_device_experts = torch.tensor(num_device_experts,
dtype=torch.int32,
device="cuda")
# num_token = token_input.shape[0]
hidden_size = token_input.shape[1]
# Setup parameters
eps = 1e-5
norm_weight = torch.randn((hidden_size, ), dtype=dtype, device="cuda")
# Initialize DeepseekAllReduce and AllReduce
deepseek_allreduce = DeepseekAllReduce(mapping=Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
)).cuda()
# Initialize RMSNorm
norm = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype).cuda()
norm.weight.data.copy_(norm_weight)
l0 = Linear(
in_features=hidden_size,
out_features=hidden_size,
bias=False,
dtype=dtype,
mapping=Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
),
tensor_parallel_mode=TensorParallelMode.ROW,
).cuda()
l0.load_weights([dict(weight=l0_weight)])
token_input_chunked = torch.chunk(token_input.clone(),
tensor_parallel_size,
dim=-1)
fc2_output = l0(
token_input_chunked[tensor_parallel_rank],
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
residual=residual,
norm_weight=norm_weight,
eps=eps,
enable_allreduce=False,
),
)
# Define fusion operation
# slice [num_global_exp, num_token, 7168] -> [num_device_exp, num_token, 7168]
active_experts_token_input_parallel = torch.chunk(
active_experts_token_input.clone(), tensor_parallel_size, dim=0)
active_experts_token_equalized = active_experts_token_input_parallel[
tensor_parallel_rank]
# slice [num_global_exp, num_token] -> [num_device_exp, num_token]
scale_parallel = torch.chunk(scale.clone(), tensor_parallel_size, dim=0)
scale_equalized = scale_parallel[tensor_parallel_rank]
fusion_op = AllReduceFusionOp.MOE_ALLREDUCE_RESIDUAL_RMS_NORM
# Run with fusion
final_hidden_states, updated_residual = deepseek_allreduce(
token_input.clone(), [
residual.clone(),
norm_weight.clone(),
tensor_num_device_experts,
scale_equalized.clone(),
active_experts_token_equalized,
fc2_output,
], eps, fusion_op)
torch_l0 = nn.Linear(in_features=hidden_size,
out_features=hidden_size,
bias=False,
dtype=dtype)
torch_l0.weight.data.copy_(l0_weight)
torch_l0.cuda()
torch_linear_output = torch_l0(token_input)
# Verify with torch reference implementation
expert_reduction = torch.sum(active_experts_token_input *
scale.unsqueeze(-1),
dim=0)
torch_before_residual = (expert_reduction + torch_linear_output)
torch_residual = torch_before_residual + residual
torch_residual = torch_residual.to(torch.float32)
torch_final_hidden_states = rms_norm(torch_residual, norm_weight,
eps).to(dtype)
# Verify results are close to reference
torch.testing.assert_close(
final_hidden_states,
torch_final_hidden_states,
rtol=0.2,
atol=0.2,
)
return True
@torch.inference_mode()
def test_moe_residual_norm_fusion():
torch.manual_seed(42)
seq_len = 16
hidden_size = 7168
dtype = torch.bfloat16
tensor_parallel_size = 2
num_global_experts = 4
# [num_token, 7168]
token_input = torch.randn((seq_len, hidden_size), dtype=dtype)
# [num_global_exp, num_token, 7168]
active_experts_token_input = torch.randn(
(num_global_experts, seq_len, hidden_size), dtype=dtype, device="cuda")
# [num_global_exp, num_token]
scale = torch.randn((num_global_experts, seq_len),
dtype=torch.float32,
device="cuda")
# [num_token, 7168]
residual = torch.randn_like(token_input)
l0_weight = torch.randn((hidden_size, hidden_size), dtype=dtype)
with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
results = executor.map(
run_moe_single_rank,
*zip(*[(tensor_parallel_size, moe_residual_norm_fusion_forward,
token_input, residual, active_experts_token_input, scale,
l0_weight)] * tensor_parallel_size),
)
for r in results:
assert r is True