From 9cc5922a0bd32f7737a2354b6e1f9a9d8405c18d Mon Sep 17 00:00:00 2001
From: Yukun He <23156053+hyukn@users.noreply.github.com>
Date: Thu, 1 May 2025 07:56:36 +0800
Subject: [PATCH] Clean up allreduce op in Deepseek V3 model. (#3829)

* Replace deepseek_allreduce op with the new unified allreduce op and moe_allreduce op.
* Minor revision of moe_allreduce op argument names.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
---
 cpp/tensorrt_llm/thop/allreduceOp.cpp         |  70 +++---
 .../_torch/custom_ops/cpp_custom_ops.py       |   8 +-
 tensorrt_llm/_torch/distributed/__init__.py   |   4 +-
 tensorrt_llm/_torch/distributed/ops.py        |  65 +++++
 .../_torch/models/modeling_deepseekv3.py      | 234 ++++++------------
 .../_torch/multi_gpu/test_allreduce.py        | 212 ++++++++++++++--
 .../multi_gpu/test_deepseek_allreduce.py      | 167 +------------
 7 files changed, 383 insertions(+), 377 deletions(-)

diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp
index fc8647e140..a244c554f7 100644
--- a/cpp/tensorrt_llm/thop/allreduceOp.cpp
+++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp
@@ -273,12 +273,11 @@ private:
     torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
     torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
 {
+    auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
     int size = input.numel();
     int hidden_size = input.size(-1);

-    torch::Tensor output = torch::empty_like(input);
-
     if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM)
     {
         torch::Tensor norm_out = torch::empty_like(input);
@@ -815,14 +814,14 @@ std::vector allreduce(torch::Tensor input, torch::optional
-std::vector<torch::Tensor> moe_allreduce(torch::Tensor residual, torch::Tensor norm_weight,
-    torch::Tensor moe_reduction_device_num_experts, torch::Tensor moe_reduction_scale_input,
-    torch::Tensor moe_reduction_active_experts_token_input, torch::Tensor moe_reduction_token_input,
-    torch::optional<torch::Tensor> workspace, int64_t const rank, int64_t const nranks, double const eps)
+// device_num_experts [1]
+// scale_input [global_num_experts, m]
+// active_experts_token_input [device_num_experts, m, hidden_dim]
+// token_input [m, hidden_dim]
+std::vector<torch::Tensor> moe_allreduce(torch::Tensor const& residual, torch::Tensor const& norm_weight,
+    torch::Tensor const& device_num_experts, torch::Tensor const& scale_input,
+    torch::Tensor const& active_experts_token_input, torch::Tensor const& token_input, torch::Tensor workspace,
+    int64_t const rank, int64_t const nranks, double const eps)
 {
     auto allreduce_fusion_params = tensorrt_llm::kernels::ar_fusion::moe::MoeReductionAllReduceFusionParams();

@@ -833,14 +832,13 @@ std::vector moe_allreduce(torch::Tensor residual, torch::Tensor n
     allreduce_fusion_params.nranks = static_cast<int>(nranks);
     allreduce_fusion_params.rank = static_cast<int>(rank);
-    allreduce_fusion_params.dtype
-        = tensorrt_llm::runtime::TorchUtils::dataType(moe_reduction_token_input.scalar_type());
+    allreduce_fusion_params.dtype = tensorrt_llm::runtime::TorchUtils::dataType(token_input.scalar_type());
     // size: num_token * hidden_dim
-    allreduce_fusion_params.size = static_cast<int>(moe_reduction_token_input.numel());
-    allreduce_fusion_params.hidden_dim = static_cast<int>(moe_reduction_active_experts_token_input.size(-1));
+    allreduce_fusion_params.size = static_cast<int>(token_input.numel());
+    allreduce_fusion_params.hidden_dim = static_cast<int>(active_experts_token_input.size(-1));

     // workspace: AR scratch space
-    allreduce_fusion_params.workspace = reinterpret_cast<void*>(workspace.value().mutable_data_ptr());
+    allreduce_fusion_params.workspace = reinterpret_cast<void*>(workspace.mutable_data_ptr());

     allreduce_fusion_params.rms_gamma = norm_weight.data_ptr();
     allreduce_fusion_params.rms_eps = static_cast<float>(eps);
@@ -850,15 +848,13 @@ std::vector moe_allreduce(torch::Tensor residual, torch::Tensor n
     // MOE Reduction specific params
     allreduce_fusion_params.allreduce_in = nullptr; // for safety, set nullptr
-    allreduce_fusion_params.moe_reduction_device_num_experts
-        = static_cast<int*>(moe_reduction_device_num_experts.data_ptr());
-    allreduce_fusion_params.moe_reduction_scale_input = static_cast<float*>(moe_reduction_scale_input.data_ptr());
-    allreduce_fusion_params.moe_reduction_active_experts_token_input
-        = moe_reduction_active_experts_token_input.data_ptr();
-    allreduce_fusion_params.moe_reduction_token_input = moe_reduction_token_input.data_ptr();
+    allreduce_fusion_params.moe_reduction_device_num_experts = static_cast<int*>(device_num_experts.data_ptr());
+    allreduce_fusion_params.moe_reduction_scale_input = static_cast<float*>(scale_input.data_ptr());
+    allreduce_fusion_params.moe_reduction_active_experts_token_input = active_experts_token_input.data_ptr();
+    allreduce_fusion_params.moe_reduction_token_input = token_input.data_ptr();

     // output tensors
-    torch::Tensor norm_out = torch::empty_like(moe_reduction_token_input);
+    torch::Tensor norm_out = torch::empty_like(token_input);
     torch::Tensor residual_out = torch::empty_like(residual);

     allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
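Note on the fused computation: per rank, trtllm::moe_allreduce combines the expert reduction, the add of the shared-expert output and the residual, the cross-rank all-reduce, and the RMSNorm in a single kernel. A minimal single-GPU PyTorch sketch of that math, following the shape comments above and the docstring/test reference later in this patch (the cross-rank all-reduce is elided and the function name is illustrative):

import torch

def moe_allreduce_reference(active_experts_token_input,  # [device_num_experts, m, hidden_dim]
                            scale_input,                  # [device_num_experts, m]
                            token_input,                  # [m, hidden_dim], shared-expert output
                            residual,                     # [m, hidden_dim]
                            norm_weight,                  # [hidden_dim]
                            eps=1e-5):
    # Weighted reduction over this rank's active experts.
    expert_reduction = torch.sum(active_experts_token_input * scale_input.unsqueeze(-1), dim=0)
    # Add the shared-expert output; the real op all-reduces this partial sum across TP ranks here.
    output_add = expert_reduction + token_input
    output_residual = output_add + residual
    # RMSNorm over the hidden dimension, computed in fp32 and cast back.
    x = output_residual.to(torch.float32)
    norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * norm_weight.to(torch.float32)
    return norm.to(token_input.dtype), output_residual

Calling this with random tensors of the commented shapes reproduces, up to the missing cross-rank reduction, the reference that the unit test added later in this patch checks against.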
workspace, " - "int rank, int nranks, float eps) -> Tensor[]"); + "moe_allreduce(" + "Tensor residual," + "Tensor norm_weight," + "Tensor device_num_experts," + "Tensor scale_input," + "Tensor active_experts_token_input," + "Tensor token_input," + "Tensor workspace," + "int rank," + "int nranks," + "float eps) -> Tensor[]"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index 522e675cd7..b4d8522080 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -53,10 +53,10 @@ def _register_fake(): return [torch.empty_like(input)] @torch.library.register_fake("trtllm::moe_allreduce") - def _(residual, norm_weight, moe_reduction_device_num_experts, - moe_reduction_scale_input, moe_reduction_active_experts_token_input, - moe_reduction_token_input, workspace, rank, nranks, eps): - norm_out = torch.empty_like(moe_reduction_token_input) + def _(residual, norm_weight, device_num_experts, scale_input, + active_experts_token_input, token_input, workspace, rank, nranks, + eps): + norm_out = torch.empty_like(token_input) residual_out = torch.empty_like(residual) return [norm_out, residual_out] diff --git a/tensorrt_llm/_torch/distributed/__init__.py b/tensorrt_llm/_torch/distributed/__init__.py index 442a8d454a..f18de7bbcd 100644 --- a/tensorrt_llm/_torch/distributed/__init__.py +++ b/tensorrt_llm/_torch/distributed/__init__.py @@ -1,11 +1,10 @@ from .communicator import Distributed, MPIDist, PPComm, TorchDist from .ops import (AllReduce, AllReduceFusionOp, AllReduceParams, - AllReduceStrategy, DeepseekAllReduce, allgather, + AllReduceStrategy, DeepseekAllReduce, MoEAllReduce, allgather, reducescatter, userbuffers_allreduce_finalize) __all__ = [ "allgather", - "allreduce", "reducescatter", "userbuffers_allreduce_finalize", "AllReduce", @@ -13,6 +12,7 @@ __all__ = [ "AllReduceFusionOp", "AllReduceStrategy", "DeepseekAllReduce", + "MoEAllReduce", "TorchDist", "PPComm", "MPIDist", diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index 949382d601..8e0f28221c 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -132,6 +132,10 @@ class AllReduce(nn.Module): - RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4 - AUTO: AUTO chooses between NCCL and MIN_LATENCY mode based on a heuristic policy. + + Note: + For the reference implementation for each pattern, please refer to the following unit test: + https://github.com/NVIDIA/TensorRT-LLM/blob/main/tests/unittest/_torch/multi_gpu/test_allreduce.py """ self.mapping = mapping @@ -196,6 +200,67 @@ class AllReduce(nn.Module): return output if len(output) > 1 else output[0] +class MoEAllReduce(nn.Module): + + def __init__(self, mapping: Mapping): + """ + MoEAllReduce is a module that performs a specific fused MoE reduction + followed by a regular AR + RMS norm. + + Args: + mapping (Mapping): The parallel mapping config. 
+ + Notes: + Support pattern: MoE Reduction + Add + AR + ADD_RMS, see this torch reference implementation: + expert_reduction = torch.sum(active_experts_token_input * + scale.unsqueeze(-1), + dim=0) + output_add = expert_reduction + shared_expert_output + output_residual = output_add + residual + output_hidden_states = rms_norm(output_residual, norm_weight, eps) + """ + super().__init__() + self.mapping = mapping + self.workspace = get_allreduce_workspace(self.mapping) + + def forward( + self, + residual: torch.Tensor, + norm_weight: torch.Tensor, + device_num_experts: torch.Tensor, + scale_input: torch.Tensor, + active_experts_token_input: torch.Tensor, + token_input: torch.Tensor, + eps: float, + ) -> torch.Tensor: + """ + Args: + residual: residual tensor + norm_weight: RMS norm weight + device_num_experts: number of experts per device + scale_input: experts to token score + active_experts_token_input: per token per expert input + token_input: per token input, shared expert output + eps: epsilon for RMSNorm + + Output: + hidden_states: hidden_states of the model + residual: residual tensor + """ + return torch.ops.trtllm.moe_allreduce( + residual=residual, + norm_weight=norm_weight, + device_num_experts=device_num_experts, + scale_input=scale_input, + active_experts_token_input=active_experts_token_input, + token_input=token_input, + workspace=self.workspace, + rank=self.mapping.tp_rank, + nranks=self.mapping.tp_size, + eps=eps, + ) + + class DeepseekAllReduce(nn.Module): def __init__(self, mapping: Mapping): diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 26ddca743d..69c51bb1d3 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -45,7 +45,7 @@ from tensorrt_llm.llmapi.utils import enable_llm_debug from ..attention_backend import AttentionMetadata from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, - DeepseekAllReduce, allgather) + MoEAllReduce, allgather) from ..model_config import ModelConfig from ..models.modeling_utils import MissingLayer, ModelConfig, support_pp from ..modules.attention import MLA @@ -392,7 +392,7 @@ class Deepseekv3MoE(nn.Module): overridden_tp_size=shared_tp_size, reduce_output=False) - self.all_reduce = AllReduce(self.mapping) + self.allreduce = AllReduce(self.mapping) self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared] self.event_dict = { key: torch.cuda.Event() @@ -516,7 +516,7 @@ class Deepseekv3MoE(nn.Module): ), f'unmatched tensor shape' final_hidden_states = shared_output + routed_output if not self.use_dp and self.mapping.tp_size > 1: - final_hidden_states = self.all_reduce( + final_hidden_states = self.allreduce( final_hidden_states, all_reduce_params=final_all_reduce_params) @@ -608,17 +608,10 @@ class DeepseekV3DecoderLayer(DecoderLayer): eps=config.rms_norm_eps, dtype=config.torch_dtype) self.layer_idx = layer_idx - self.all_reduce = AllReduce(self.mapping) + self.allreduce = AllReduce(self.mapping) + self.moe_allreduce = MoEAllReduce(self.mapping) self.next_layer_layernorm: RMSNorm = None - self.deepseek_allreduce_disabled = os.environ.get( - "TRTLLM_DEEPSEEK_ALLREDUCE_FUSION_DISABLED", "0") == "1" - if mapping.is_multi_node(): - self.deepseek_allreduce_disabled = True - - if not self.deepseek_allreduce_disabled: - self.deepseek_allreduce = DeepseekAllReduce(self.mapping) - def 
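A usage sketch for the new module, mirroring how the model code below calls it. Shapes and dtypes follow the schema comments in allreduceOp.cpp and the unit test added later in this patch; the concrete sizes are illustrative, and actually running it requires one process per TP rank (for example under mpirun) with a CUDA device per rank, since the op performs a cross-rank all-reduce:

import torch
from tensorrt_llm._torch.distributed import MoEAllReduce
from tensorrt_llm.mapping import Mapping

tp_size, rank = 2, 0                      # illustrative 2-way tensor parallelism
m, hidden_dim, device_num_experts = 16, 7168, 2

mapping = Mapping(world_size=tp_size, tp_size=tp_size, rank=rank)
moe_allreduce = MoEAllReduce(mapping)

dtype = torch.bfloat16
hidden_states, residual_out = moe_allreduce(
    residual=torch.randn(m, hidden_dim, dtype=dtype, device="cuda"),
    norm_weight=torch.randn(hidden_dim, dtype=dtype, device="cuda"),
    device_num_experts=torch.tensor([device_num_experts], dtype=torch.int32, device="cuda"),
    scale_input=torch.randn(device_num_experts, m, dtype=torch.float32, device="cuda"),
    active_experts_token_input=torch.randn(device_num_experts, m, hidden_dim, dtype=dtype, device="cuda"),
    token_input=torch.randn(m, hidden_dim, dtype=dtype, device="cuda"),
    eps=1e-6,
)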
diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index 26ddca743d..69c51bb1d3 100644
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -45,7 +45,7 @@ from tensorrt_llm.llmapi.utils import enable_llm_debug
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
 from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
-                           DeepseekAllReduce, allgather)
+                           MoEAllReduce, allgather)
 from ..model_config import ModelConfig
 from ..models.modeling_utils import MissingLayer, ModelConfig, support_pp
 from ..modules.attention import MLA
@@ -392,7 +392,7 @@ class Deepseekv3MoE(nn.Module):
             overridden_tp_size=shared_tp_size,
             reduce_output=False)

-        self.all_reduce = AllReduce(self.mapping)
+        self.allreduce = AllReduce(self.mapping)
         self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
         self.event_dict = {
             key: torch.cuda.Event()
@@ -516,7 +516,7 @@ class Deepseekv3MoE(nn.Module):
         ), f'unmatched tensor shape'
         final_hidden_states = shared_output + routed_output
         if not self.use_dp and self.mapping.tp_size > 1:
-            final_hidden_states = self.all_reduce(
+            final_hidden_states = self.allreduce(
                 final_hidden_states,
                 all_reduce_params=final_all_reduce_params)

@@ -608,17 +608,10 @@ class DeepseekV3DecoderLayer(DecoderLayer):
                                                 eps=config.rms_norm_eps,
                                                 dtype=config.torch_dtype)
         self.layer_idx = layer_idx
-        self.all_reduce = AllReduce(self.mapping)
+        self.allreduce = AllReduce(self.mapping)
+        self.moe_allreduce = MoEAllReduce(self.mapping)
         self.next_layer_layernorm: RMSNorm = None

-        self.deepseek_allreduce_disabled = os.environ.get(
-            "TRTLLM_DEEPSEEK_ALLREDUCE_FUSION_DISABLED", "0") == "1"
-        if mapping.is_multi_node():
-            self.deepseek_allreduce_disabled = True
-
-        if not self.deepseek_allreduce_disabled:
-            self.deepseek_allreduce = DeepseekAllReduce(self.mapping)
-
     def _compute_mlp_tp_size(self, intermediate_size: int,
                              block_size: int) -> int:
         """
@@ -675,19 +668,25 @@ class DeepseekV3DecoderLayer(DecoderLayer):
             **kwargs,
         )

-        # deepseek allreduce kernel is better when m < 512, two shot(128~512) has acc bug, waive
-        using_prev_fusion = self.deepseek_allreduce_disabled or hidden_states.size(
-            0) > 128
-
-        min_latency_mode = self._enable_latency_mode(
-            hidden_states.size(0)) and not using_prev_fusion
+        min_latency_mode = self._enable_latency_mode(hidden_states.size(0))

         hidden_states_fp4 = None
         if self.fusion_config.PRE_MOE_FUSION:
-            # Custom AR Fusion for DeepseekV3
-            if using_prev_fusion:
-                # Custom AR Fusion for DeepseekV3
-                hidden_states, residual = self.all_reduce(
+            if min_latency_mode:
+                hidden_states, hidden_states_act, hidden_states_sf, residual = self.allreduce(
+                    hidden_states,
+                    all_reduce_params=AllReduceParams(
+                        fusion_op=AllReduceFusionOp.
+                        RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4,
+                        residual=residual,
+                        norm_weight=self.post_attention_layernorm.weight,
+                        scale=self.mlp.experts.fc31_input_scale,
+                        eps=self.post_attention_layernorm.variance_epsilon,
+                    ))
+                hidden_states_fp4 = Fp4QuantizedTensor(hidden_states_act,
+                                                       hidden_states_sf)
+            else:
+                hidden_states, residual = self.allreduce(
                     hidden_states,
                     all_reduce_params=AllReduceParams(
                         fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
                         residual=residual,
                         norm_weight=self.post_attention_layernorm.weight,
                         eps=self.post_attention_layernorm.variance_epsilon,
                     ))
-            else:
-                if min_latency_mode:
-                    hidden_states, hidden_states_act, hidden_states_sf, residual = self.deepseek_allreduce(
-                        hidden_states,
-                        [
-                            residual, self.post_attention_layernorm.weight,
-                            self.mlp.experts.fc31_input_scale
-                        ],
-                        self.post_attention_layernorm.variance_epsilon,
-                        AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4,
-                    )
-                    hidden_states_fp4 = Fp4QuantizedTensor(
-                        hidden_states_act, hidden_states_sf)
-                else:
-                    hidden_states, residual = self.deepseek_allreduce(
-                        hidden_states,
-                        [residual, self.post_attention_layernorm.weight],
-                        self.post_attention_layernorm.variance_epsilon,
-                        AllReduceFusionOp.RESIDUAL_RMS_NORM,
-                    )
         elif self.fusion_config.PRE_MLP_FUSION:
-            # Custom AR Fusion for DeepseekV3 with quant_fp4
-            if using_prev_fusion:
-                hidden_states, residual = self.all_reduce(
-                    hidden_states,
-                    all_reduce_params=AllReduceParams(
-                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-                        residual=residual,
-                        norm_weight=self.post_attention_layernorm.weight,
-                        eps=self.post_attention_layernorm.variance_epsilon,
-                    ))
-                act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
-                    hidden_states, self.mlp.gate_up_proj.input_scale,
-                    self.mlp.gate_up_proj.scaling_vector_size, False)
-            else:
-                act_fp4, act_sf, residual = self.deepseek_allreduce(
-                    hidden_states,
-                    [
-                        residual, self.post_attention_layernorm.weight,
-                        self.mlp.gate_up_proj.input_scale
-                    ],
-                    self.post_attention_layernorm.variance_epsilon,
-                    AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4,
-                )
+            act_fp4, act_sf, residual = self.allreduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4,
+                    residual=residual,
+                    norm_weight=self.post_attention_layernorm.weight,
+                    scale=self.mlp.gate_up_proj.input_scale,
+                    eps=self.post_attention_layernorm.variance_epsilon,
+                ))
             hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
-
         else:
             # No fusion
             hidden_states, residual = self.post_attention_layernorm(
@@ -769,62 +733,39 @@ class DeepseekV3DecoderLayer(DecoderLayer):
         )

         if self.fusion_config.POST_MOE_FUSION:
-            if using_prev_fusion:
-                hidden_states, residual = self.all_reduce(
-                    hidden_states,
-                    all_reduce_params=AllReduceParams(
-                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-                        residual=residual,
-                        norm_weight=self.next_layer_layernorm.weight,
-                        eps=self.next_layer_layernorm.variance_epsilon,
-                    ))
-            else:
-                if min_latency_mode:
-                    shared_output = hidden_states[0]
-                    hidden_states_activated_experts = hidden_states[1]
-                    num_activated_experts_per_node = hidden_states[2]
-                    experts_to_token_score = hidden_states[3]
-                    activated_expert_global_ids = hidden_states[4]
+            if min_latency_mode:
+                shared_output = hidden_states[0]
+                hidden_states_activated_experts = hidden_states[1]
+                num_activated_experts_per_node = hidden_states[2]
+                experts_to_token_score = hidden_states[3]

-                    hidden_states, residual = self.deepseek_allreduce(
-                        hidden_states_activated_experts,  # not used
-                        [
-                            residual, self.next_layer_layernorm.weight,
-                            num_activated_experts_per_node,
-                            experts_to_token_score,
-                            hidden_states_activated_experts, shared_output,
-                            activated_expert_global_ids
-                        ],
-                        self.next_layer_layernorm.variance_epsilon,
-                        AllReduceFusionOp.MOE_ALLREDUCE_RESIDUAL_RMS_NORM,
-                    )
-                else:
-                    hidden_states, residual = self.deepseek_allreduce(
-                        hidden_states,
-                        [residual, self.next_layer_layernorm.weight],
-                        self.next_layer_layernorm.variance_epsilon,
-                        AllReduceFusionOp.RESIDUAL_RMS_NORM,
-                    )
-        elif self.fusion_config.POST_MLP_FUSION:
-
-            if using_prev_fusion:
-                # Custom AR Fusion for DeepseekV3
-                hidden_states, residual = self.all_reduce(
-                    hidden_states,
-                    all_reduce_params=AllReduceParams(
-                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-                        residual=residual,
-                        norm_weight=self.next_layer_layernorm.weight,
-                        eps=self.next_layer_layernorm.variance_epsilon,
-                    ))
-            else:
-                hidden_states, residual = self.deepseek_allreduce(
-                    hidden_states,
-                    [residual, self.next_layer_layernorm.weight],
-                    self.next_layer_layernorm.variance_epsilon,
-                    AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                hidden_states, residual = self.moe_allreduce(
+                    residual,
+                    self.next_layer_layernorm.weight,
+                    device_num_experts=num_activated_experts_per_node,
+                    scale_input=experts_to_token_score,
+                    active_experts_token_input=hidden_states_activated_experts,
+                    token_input=shared_output,
+                    eps=self.next_layer_layernorm.variance_epsilon,
                 )
-
+            else:
+                hidden_states, residual = self.allreduce(
+                    hidden_states,
+                    all_reduce_params=AllReduceParams(
+                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                        residual=residual,
+                        norm_weight=self.next_layer_layernorm.weight,
+                        eps=self.next_layer_layernorm.variance_epsilon,
+                    ))
+        elif self.fusion_config.POST_MLP_FUSION:
+            hidden_states, residual = self.allreduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                    residual=residual,
+                    norm_weight=self.next_layer_layernorm.weight,
+                    eps=self.next_layer_layernorm.variance_epsilon,
+                ))
         else:
             if self.next_layer_layernorm is not None:
                 hidden_states, residual = self.next_layer_layernorm(
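For the min-latency path above, the tuple returned by the MoE maps onto MoEAllReduce.forward as the call shows; activated_expert_global_ids, element 4 of the tuple consumed by the old deepseek_allreduce call, is not needed by the new op. In sketch form:

# hidden_states[0] -> token_input                 (shared-expert output)
# hidden_states[1] -> active_experts_token_input  (per-expert, per-token activations)
# hidden_states[2] -> device_num_experts          (experts active on this rank)
# hidden_states[3] -> scale_input                 (expert-to-token scores)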
@@ -878,9 +819,6 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
     ) -> Tuple[torch.Tensor, torch.Tensor]:

         # deepseek allreduce kernel is better when m < 512
-        using_prev_fusion = self.deepseek_allreduce_disabled or hidden_states.size(
-            0) >= 512
-
         inputs_embeds = self.enorm(embed_tokens(input_ids))
         hidden_states = self.hnorm(hidden_states)
         hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1)
@@ -902,24 +840,14 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):

         # MTP Layer Must have sparse MOE
         if self.fusion_config.PRE_MOE_FUSION:
-            # Custom AR Fusion for DeepseekV3
-            if using_prev_fusion:
-                # Custom AR Fusion for DeepseekV3
-                hidden_states, residual = self.all_reduce(
-                    hidden_states,
-                    all_reduce_params=AllReduceParams(
-                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-                        residual=residual,
-                        norm_weight=self.post_attention_layernorm.weight,
-                        eps=self.post_attention_layernorm.variance_epsilon,
-                    ))
-            else:
-                hidden_states, residual = self.deepseek_allreduce(
-                    hidden_states,
-                    [residual, self.post_attention_layernorm.weight],
-                    self.post_attention_layernorm.variance_epsilon,
-                    AllReduceFusionOp.RESIDUAL_RMS_NORM,
-                )
+            hidden_states, residual = self.allreduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                    residual=residual,
+                    norm_weight=self.post_attention_layernorm.weight,
+                    eps=self.post_attention_layernorm.variance_epsilon,
+                ))
         else:
             hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual)
@@ -933,22 +861,14 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
         )

         if self.fusion_config.POST_MOE_FUSION:
-            if using_prev_fusion:
-                hidden_states, residual = self.all_reduce(
-                    hidden_states,
-                    all_reduce_params=AllReduceParams(
-                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-                        residual=residual,
-                        norm_weight=self.shared_head.norm.weight,
-                        eps=self.shared_head.norm.variance_epsilon,
-                    ))
-            else:
-                hidden_states, residual = self.deepseek_allreduce(
-                    hidden_states,
-                    [residual, self.shared_head.norm.weight],
-                    self.shared_head.norm.variance_epsilon,
-                    AllReduceFusionOp.RESIDUAL_RMS_NORM,
-                )
+            hidden_states, residual = self.allreduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                    residual=residual,
+                    norm_weight=self.shared_head.norm.weight,
+                    eps=self.shared_head.norm.variance_epsilon,
+                ))
         else:
             hidden_states, _ = self.shared_head.norm(hidden_states, residual)
diff --git a/tests/unittest/_torch/multi_gpu/test_allreduce.py b/tests/unittest/_torch/multi_gpu/test_allreduce.py
index f48c8f07f1..c9fe1731a8 100644
--- a/tests/unittest/_torch/multi_gpu/test_allreduce.py
+++ b/tests/unittest/_torch/multi_gpu/test_allreduce.py
@@ -26,7 +26,7 @@ from utils.util import skip_pre_blackwell

 import tensorrt_llm
 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
-                                             AllReduceParams)
+                                             AllReduceParams, MoEAllReduce)
 from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
 from tensorrt_llm._torch.modules.rms_norm import RMSNorm
 from tensorrt_llm.mapping import Mapping
@@ -72,6 +72,21 @@ def run_single_rank(tensor_parallel_size, single_rank_forward_func, input,
     return True


+def run_moe_single_rank(tensor_parallel_size, single_rank_forward_func,
+                        token_input, residual, active_experts_token_input,
+                        scale, l0_weight):
+    rank = tensorrt_llm.mpi_rank()
+    torch.cuda.set_device(rank)
+    try:
+        single_rank_forward_func(token_input, residual,
+                                 active_experts_token_input, scale,
+                                 tensor_parallel_size, rank, l0_weight)
+    except Exception:
+        traceback.print_exc()
+        raise
+    return True
+
+
 @torch.inference_mode()
 def run_allreduce_op(x: torch.Tensor, residual: torch.Tensor, hidden_size: int,
                      dtype: torch.dtype, tensor_parallel_size: int,
@@ -238,28 +253,29 @@ def run_allreduce_op(x: torch.Tensor, residual: torch.Tensor, hidden_size: int,
     assert mismatch_percentage < 0.01, f"Large mismatched elements encountered"


-@skip_pre_blackwell
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Requires at least 2 GPUs for this test")
@pytest.mark.parametrize("seq_len", [16, 256], ids=lambda x: f"seqlen:{x}") @pytest.mark.parametrize("hidden_size", [128, 7168], ids=lambda x: f"hidden:{x}") -@pytest.mark.parametrize("fusion_op", [ - AllReduceFusionOp.NONE, - AllReduceFusionOp.RESIDUAL_RMS_NORM, - AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8, - AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8, - AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4, - AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4, -], - ids=[ - "none", - "residual_rms_norm", - "residual_rms_norm_quant_fp8", - "residual_rms_norm_out_quant_fp8", - "residual_rms_norm_quant_nvfp4", - "residual_rms_norm_out_quant_nvfp4", - ]) +@pytest.mark.parametrize( + "fusion_op", + [ + pytest.param(AllReduceFusionOp.NONE, id="none"), + pytest.param(AllReduceFusionOp.RESIDUAL_RMS_NORM, + id="residual_rms_norm"), + pytest.param(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8, + id="residual_rms_norm_quant_fp8"), + pytest.param(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8, + id="residual_rms_norm_out_quant_fp8"), + pytest.param(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4, + id="residual_rms_norm_quant_nvfp4", + marks=skip_pre_blackwell), + pytest.param(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4, + id="residual_rms_norm_out_quant_nvfp4", + marks=skip_pre_blackwell), + ], +) def test_allreduce_fusion_patterns(seq_len, hidden_size, fusion_op): torch.manual_seed(0) dtype = torch.bfloat16 @@ -276,3 +292,163 @@ def test_allreduce_fusion_patterns(seq_len, hidden_size, fusion_op): ) for r in results: assert r is True + + +@torch.inference_mode() +def run_moe_allreduce_op(token_input: torch.Tensor, residual: torch.Tensor, + active_experts_token_input: torch.Tensor, + scale: torch.Tensor, tensor_parallel_size: int, + tensor_parallel_rank: int, l0_weight: torch.Tensor): + torch.manual_seed(42) + + # * token_input: + # [num_token, 7168] + # different val for different device + # * active_experts_token_input + # [num_global_exp, num_token, 7168] + # need to slice to [num_device_exp, num_token, 7168] before use + # * scale + # [num_global_exp, num_token] + # per expert per token scale + # need to slice to [num_device_exp, num_token, 7168] before use + # different value for each device + + token_input = token_input.cuda() + residual = residual.cuda() + active_experts_token_input = active_experts_token_input.cuda() + scale = scale.cuda() + + dtype = token_input.dtype + num_global_experts = scale.size(0) + num_device_experts = num_global_experts // tensor_parallel_size + tensor_num_device_experts = torch.tensor(num_device_experts, + dtype=torch.int32, + device="cuda") + # num_token = token_input.shape[0] + hidden_size = token_input.shape[1] + + # Setup parameters + eps = 1e-5 + norm_weight = torch.randn((hidden_size, ), dtype=dtype, device="cuda") + + # Initialize MoEAllreduce + moe_allreduce = MoEAllReduce(mapping=Mapping( + world_size=tensor_parallel_size, + tp_size=tensor_parallel_size, + rank=tensor_parallel_rank, + )).cuda() + + # Initialize RMSNorm + norm = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype).cuda() + norm.weight.data.copy_(norm_weight) + + l0 = Linear( + in_features=hidden_size, + out_features=hidden_size, + bias=False, + dtype=dtype, + mapping=Mapping( + world_size=tensor_parallel_size, + tp_size=tensor_parallel_size, + rank=tensor_parallel_rank, + ), + tensor_parallel_mode=TensorParallelMode.ROW, + ).cuda() + l0.load_weights([dict(weight=l0_weight)]) + token_input_chunked = torch.chunk(token_input.clone(), + tensor_parallel_size, + dim=-1) + 
+    fc2_output = l0(
+        token_input_chunked[tensor_parallel_rank],
+        all_reduce_params=AllReduceParams(
+            fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+            residual=residual,
+            norm_weight=norm_weight,
+            eps=eps,
+            enable_allreduce=False,
+        ),
+    )
+
+    # Define fusion operation
+    # slice [num_global_exp, num_token, 7168] -> [num_device_exp, num_token, 7168]
+    active_experts_token_input_parallel = torch.chunk(
+        active_experts_token_input.clone(), tensor_parallel_size, dim=0)
+    active_experts_token_equalized = active_experts_token_input_parallel[
+        tensor_parallel_rank]
+
+    # slice [num_global_exp, num_token] -> [num_device_exp, num_token]
+    scale_parallel = torch.chunk(scale.clone(), tensor_parallel_size, dim=0)
+    scale_equalized = scale_parallel[tensor_parallel_rank]
+
+    # Run with fusion
+    output_hidden_states, output_residual = moe_allreduce(
+        residual,
+        norm_weight,
+        tensor_num_device_experts,
+        scale_equalized,
+        active_experts_token_equalized,
+        fc2_output,
+        eps,
+    )
+
+    torch_l0 = torch.nn.Linear(in_features=hidden_size,
+                               out_features=hidden_size,
+                               bias=False,
+                               dtype=dtype)
+    torch_l0.weight.data.copy_(l0_weight)
+    torch_l0.cuda()
+
+    torch_linear_output = torch_l0(token_input)
+    # Verify with torch reference implementation
+    expert_reduction = torch.sum(active_experts_token_input *
+                                 scale.unsqueeze(-1),
+                                 dim=0)
+    torch_before_residual = expert_reduction + torch_linear_output
+    torch_residual = torch_before_residual + residual
+    torch_residual = torch_residual.to(torch.float32)
+    torch_output_hidden_states = rms_norm(torch_residual, norm_weight,
+                                          eps).to(dtype)
+
+    # Verify results are close to reference
+    torch.testing.assert_close(
+        output_hidden_states,
+        torch_output_hidden_states,
+        rtol=0.2,
+        atol=0.2,
+    )
+
+    return True
+
+
+@torch.inference_mode()
+def test_moe_allreduce_patterns():
+    torch.manual_seed(42)
+
+    seq_len = 16
+    hidden_size = 7168
+    dtype = torch.bfloat16
+    tensor_parallel_size = 2
+    num_global_experts = 4
+
+    # [num_token, 7168]
+    token_input = torch.randn((seq_len, hidden_size), dtype=dtype)
+    # [num_global_exp, num_token, 7168]
+    active_experts_token_input = torch.randn(
+        (num_global_experts, seq_len, hidden_size), dtype=dtype, device="cuda")
+    # [num_global_exp, num_token]
+    scale = torch.randn((num_global_experts, seq_len),
+                        dtype=torch.float32,
+                        device="cuda")
+    # [num_token, 7168]
+    residual = torch.randn_like(token_input)
+
+    l0_weight = torch.randn((hidden_size, hidden_size), dtype=dtype)
+    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
+        results = executor.map(
+            run_moe_single_rank,
+            *zip(*[(tensor_parallel_size, run_moe_allreduce_op, token_input,
+                    residual, active_experts_token_input, scale, l0_weight)] *
+                 tensor_parallel_size),
+        )
+        for r in results:
+            assert r is True
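Both run_moe_allreduce_op above and the removed test below compare against a rms_norm helper that these test files define or import outside the changed hunks, so it does not appear in this patch. A minimal reference consistent with how it is called here, assuming the conventional RMSNorm formula:

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # x is passed in as float32 here; normalize over the last (hidden) dimension.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight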
diff --git a/tests/unittest/_torch/multi_gpu/test_deepseek_allreduce.py b/tests/unittest/_torch/multi_gpu/test_deepseek_allreduce.py
index 85b2fdabb1..ac6cda110f 100644
--- a/tests/unittest/_torch/multi_gpu/test_deepseek_allreduce.py
+++ b/tests/unittest/_torch/multi_gpu/test_deepseek_allreduce.py
@@ -19,15 +19,13 @@ import traceback
 import cloudpickle
 import pytest
 import torch
-import torch.nn as nn
 from mpi4py import MPI
 from mpi4py.futures import MPIPoolExecutor
 from utils.util import skip_pre_blackwell

 import tensorrt_llm
 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
-                                             AllReduceParams, DeepseekAllReduce)
-from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
+                                             DeepseekAllReduce)
 from tensorrt_llm._torch.modules.rms_norm import RMSNorm
 from tensorrt_llm.mapping import Mapping
@@ -227,166 +225,3 @@ def test_row_linear_residual_norm_fusion(seq_len, hidden_size, fusion_op):
     )
     for r in results:
         assert r is True
-
-
-@torch.inference_mode()
-def moe_residual_norm_fusion_forward(
-        token_input: torch.Tensor, residual: torch.Tensor,
-        active_experts_token_input: torch.Tensor, scale: torch.Tensor,
-        tensor_parallel_size: int, tensor_parallel_rank: int,
-        l0_weight: torch.Tensor):
-    torch.manual_seed(42)
-
-    # * token_input:
-    #   [num_token, 7168]
-    #   different val for different device
-    # * active_experts_token_input
-    #   [num_global_exp, num_token, 7168]
-    #   need to slice to [num_device_exp, num_token, 7168] before use
-    # * scale
-    #   [num_global_exp, num_token]
-    #   per expert per token scale
-    #   need to slice to [num_device_exp, num_token, 7168] before use
-    #   different value for each device
-
-    token_input = token_input.cuda()
-    residual = residual.cuda()
-    active_experts_token_input = active_experts_token_input.cuda()
-    scale = scale.cuda()
-
-    dtype = token_input.dtype
-    num_global_experts = scale.size(0)
-    num_device_experts = num_global_experts // tensor_parallel_size
-    tensor_num_device_experts = torch.tensor(num_device_experts,
-                                             dtype=torch.int32,
-                                             device="cuda")
-    # num_token = token_input.shape[0]
-    hidden_size = token_input.shape[1]
-
-    # Setup parameters
-    eps = 1e-5
-    norm_weight = torch.randn((hidden_size, ), dtype=dtype, device="cuda")
-
-    # Initialize DeepseekAllReduce and AllReduce
-    deepseek_allreduce = DeepseekAllReduce(mapping=Mapping(
-        world_size=tensor_parallel_size,
-        tp_size=tensor_parallel_size,
-        rank=tensor_parallel_rank,
-    )).cuda()
-
-    # Initialize RMSNorm
-    norm = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype).cuda()
-    norm.weight.data.copy_(norm_weight)
-
-    l0 = Linear(
-        in_features=hidden_size,
-        out_features=hidden_size,
-        bias=False,
-        dtype=dtype,
-        mapping=Mapping(
-            world_size=tensor_parallel_size,
-            tp_size=tensor_parallel_size,
-            rank=tensor_parallel_rank,
-        ),
-        tensor_parallel_mode=TensorParallelMode.ROW,
-    ).cuda()
-    l0.load_weights([dict(weight=l0_weight)])
-    token_input_chunked = torch.chunk(token_input.clone(),
-                                      tensor_parallel_size,
-                                      dim=-1)
-    fc2_output = l0(
-        token_input_chunked[tensor_parallel_rank],
-        all_reduce_params=AllReduceParams(
-            fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-            residual=residual,
-            norm_weight=norm_weight,
-            eps=eps,
-            enable_allreduce=False,
-        ),
-    )
-
-    # Define fusion operation
-    # slice [num_global_exp, num_token, 7168] -> [num_device_exp, num_token, 7168]
-    active_experts_token_input_parallel = torch.chunk(
-        active_experts_token_input.clone(), tensor_parallel_size, dim=0)
-    active_experts_token_equalized = active_experts_token_input_parallel[
-        tensor_parallel_rank]
-
-    # slice [num_global_exp, num_token] -> [num_device_exp, num_token]
-    scale_parallel = torch.chunk(scale.clone(), tensor_parallel_size, dim=0)
-    scale_equalized = scale_parallel[tensor_parallel_rank]
-
-    fusion_op = AllReduceFusionOp.MOE_ALLREDUCE_RESIDUAL_RMS_NORM
-
-    # Run with fusion
-    final_hidden_states, updated_residual = deepseek_allreduce(
-        token_input.clone(), [
-            residual.clone(),
-            norm_weight.clone(),
-            tensor_num_device_experts,
-            scale_equalized.clone(),
-            active_experts_token_equalized,
-            fc2_output,
-        ], eps, fusion_op)
-
-    torch_l0 = nn.Linear(in_features=hidden_size,
-                         out_features=hidden_size,
-                         bias=False,
-                         dtype=dtype)
-    torch_l0.weight.data.copy_(l0_weight)
-    torch_l0.cuda()
-
-    torch_linear_output = torch_l0(token_input)
-    # Verify with torch reference implementation
-    expert_reduction = torch.sum(active_experts_token_input *
-                                 scale.unsqueeze(-1),
-                                 dim=0)
-    torch_before_residual = (expert_reduction + torch_linear_output)
-    torch_residual = torch_before_residual + residual
-    torch_residual = torch_residual.to(torch.float32)
-    torch_final_hidden_states = rms_norm(torch_residual, norm_weight,
-                                         eps).to(dtype)
-
-    # Verify results are close to reference
-    torch.testing.assert_close(
-        final_hidden_states,
-        torch_final_hidden_states,
-        rtol=0.2,
-        atol=0.2,
-    )
-
-    return True
-
-
-@torch.inference_mode()
-def test_moe_residual_norm_fusion():
-    torch.manual_seed(42)
-
-    seq_len = 16
-    hidden_size = 7168
-    dtype = torch.bfloat16
-    tensor_parallel_size = 2
-    num_global_experts = 4
-
-    # [num_token, 7168]
-    token_input = torch.randn((seq_len, hidden_size), dtype=dtype)
-    # [num_global_exp, num_token, 7168]
-    active_experts_token_input = torch.randn(
-        (num_global_experts, seq_len, hidden_size), dtype=dtype, device="cuda")
-    # [num_global_exp, num_token]
-    scale = torch.randn((num_global_experts, seq_len),
-                        dtype=torch.float32,
-                        device="cuda")
-    # [num_token, 7168]
-    residual = torch.randn_like(token_input)
-
-    l0_weight = torch.randn((hidden_size, hidden_size), dtype=dtype)
-    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
-        results = executor.map(
-            run_moe_single_rank,
-            *zip(*[(tensor_parallel_size, moe_residual_norm_fusion_forward,
-                    token_input, residual, active_experts_token_input, scale,
-                    l0_weight)] * tensor_parallel_size),
-        )
-        for r in results:
-            assert r is True