diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm.py
index 25e19ee43a..deaf570b31 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm.py
@@ -9,7 +9,7 @@ from tensorrt_llm.quantization.utils import fp4_utils
 
 from ...distributed import allgather
 from ...model_config import ModelConfig
-from ...utils import AuxStreamType, Fp4QuantizedTensor, swizzle_sf, unswizzle_sf
+from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor, swizzle_sf, unswizzle_sf
 from .fused_moe_cutlass import CutlassFusedMoE
 from .interface import AlltoallMethodType
 from .quantization import MoEWeightLoadingMode
@@ -137,6 +137,16 @@ class DenseGEMMFusedMoE(CutlassFusedMoE):
             init_load_balancer=init_load_balancer,
             without_comm=without_comm,
         )
+        # Initialize auxiliary stream and events for gen_fc2_alpha overlap with fc1
+        if self.aux_stream_dict is None:
+            self.aux_stream_dict = aux_stream_dict if aux_stream_dict is not None else {}
+        if AuxStreamType.MoeFc2Alpha not in self.aux_stream_dict:
+            self.aux_stream_dict[AuxStreamType.MoeFc2Alpha] = torch.cuda.Stream()
+        if self.event_dict is None:
+            self.event_dict = {}
+        for key in [EventType.Main, EventType.MoeFc2Alpha]:  # NOTE(review): confirm EventType actually defines MoeFc2Alpha — the utils.py hunk only extends aux_stream_name_list/AuxStreamType
+            if key not in self.event_dict:
+                self.event_dict[key] = torch.cuda.Event()
 
     def load_weights(self, weights: List[Dict]):
         super().load_weights(weights)
@@ -272,6 +282,7 @@ class DenseGEMMFusedMoE(CutlassFusedMoE):
 
         output_dtype = torch.bfloat16
         num_tokens = x.shape[0]
+        self.event_dict[EventType.Main].record()
         x_sf = swizzle_sf(x_sf, num_tokens, self.hidden_size)
         fc1_output = torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(
             x,
@@ -282,17 +293,23 @@ class DenseGEMMFusedMoE(CutlassFusedMoE):
             output_dtype,
         )
 
+        # Overlap fc2 alpha generation (aux stream) with the fc1 epilogue (main stream).
+        with torch.cuda.stream(self.aux_stream_dict[AuxStreamType.MoeFc2Alpha]):
+            self.event_dict[EventType.Main].wait()
+            fc2_alpha = self.gen_fc2_alpha(
+                num_tokens, token_selected_experts, token_final_scales, self.fc2_alpha
+            )
+            alpha_normal_factor = self.fc2_alpha_max
+            alpha_normalized = fc2_alpha / alpha_normal_factor
+            self.event_dict[EventType.MoeFc2Alpha].record()
+
         fc1_output = fc1_output.view(
             [num_tokens, self.expert_size_per_partition, -1]
         ) * self.fc31_alpha.view([1, -1, 1])
         fc1_output = fc1_output.to(output_dtype)
-        swiglu_out = swiglu_fused_moe(fc1_output)
-        alpha = self.gen_fc2_alpha(
-            num_tokens, token_selected_experts, token_final_scales, self.fc2_alpha
-        )
-        alpha_normal_factor = self.fc2_alpha_max
-        alpha_normalized = alpha / alpha_normal_factor
+        swiglu_out = swiglu_fused_moe(fc1_output)  # fix: the original patch deleted this assignment but swiglu_out is still consumed below (NameError)
+
+        self.event_dict[EventType.MoeFc2Alpha].wait()  # NOTE(review): alpha_normalized is allocated on the aux stream and consumed here — consider alpha_normalized.record_stream(torch.cuda.current_stream()) to guard allocator reuse
         fc2_input = (
             (swiglu_out * alpha_normalized.view([num_tokens, self.expert_size_per_partition, 1]))
             .view([num_tokens, -1])
         )
diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py
index 55276832ff..a9ce7a4c73 100644
--- a/tensorrt_llm/_torch/utils.py
+++ b/tensorrt_llm/_torch/utils.py
@@ -22,6 +22,7 @@ aux_stream_name_list = [
     'MoeChunkingOverlap',
     'MoeBalancer',
     'MoeOutputMemset',
+    'MoeFc2Alpha',
 ]
 AuxStreamType = Enum(
     'AuxStreamType',