Overlap gen_fc2_alpha with the fc1 GEMM on an auxiliary CUDA stream in DenseGEMMFusedMoE

Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
This commit is contained in:
Zongfei Jing 2025-12-24 02:10:38 -08:00
parent 5ddbe3ca76
commit 578c0a8e28
2 changed files with 24 additions and 7 deletions

View File

@ -9,7 +9,7 @@ from tensorrt_llm.quantization.utils import fp4_utils
from ...distributed import allgather
from ...model_config import ModelConfig
from ...utils import AuxStreamType, Fp4QuantizedTensor, swizzle_sf, unswizzle_sf
from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor, swizzle_sf, unswizzle_sf
from .fused_moe_cutlass import CutlassFusedMoE
from .interface import AlltoallMethodType
from .quantization import MoEWeightLoadingMode
@ -137,6 +137,16 @@ class DenseGEMMFusedMoE(CutlassFusedMoE):
init_load_balancer=init_load_balancer,
without_comm=without_comm,
)
# Initialize auxiliary stream and events for gen_fc2_alpha overlap with fc1
if self.aux_stream_dict is None:
self.aux_stream_dict = aux_stream_dict if aux_stream_dict is not None else {}
if AuxStreamType.MoeFc2Alpha not in self.aux_stream_dict:
self.aux_stream_dict[AuxStreamType.MoeFc2Alpha] = torch.cuda.Stream()
if self.event_dict is None:
self.event_dict = {}
for key in [EventType.Main, EventType.MoeFc2Alpha]:
if key not in self.event_dict:
self.event_dict[key] = torch.cuda.Event()
def load_weights(self, weights: List[Dict]):
super().load_weights(weights)
@ -272,6 +282,7 @@ class DenseGEMMFusedMoE(CutlassFusedMoE):
output_dtype = torch.bfloat16
num_tokens = x.shape[0]
self.event_dict[EventType.Main].record()
x_sf = swizzle_sf(x_sf, num_tokens, self.hidden_size)
fc1_output = torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(
x,
@ -282,17 +293,22 @@ class DenseGEMMFusedMoE(CutlassFusedMoE):
output_dtype,
)
with torch.cuda.stream(self.aux_stream_dict[AuxStreamType.MoeFc2Alpha]):
self.event_dict[EventType.Main].wait()
fc2_alpha = self.gen_fc2_alpha(
num_tokens, token_selected_experts, token_final_scales, self.fc2_alpha
)
alpha_normal_factor = self.fc2_alpha_max
alpha_normalized = fc2_alpha / alpha_normal_factor
self.event_dict[EventType.MoeFc2Alpha].record()
fc1_output = fc1_output.view(
[num_tokens, self.expert_size_per_partition, -1]
) * self.fc31_alpha.view([1, -1, 1])
fc1_output = fc1_output.to(output_dtype)
swiglu_out = swiglu_fused_moe(fc1_output)
alpha = self.gen_fc2_alpha(
num_tokens, token_selected_experts, token_final_scales, self.fc2_alpha
)
alpha_normal_factor = self.fc2_alpha_max
alpha_normalized = alpha / alpha_normal_factor
self.event_dict[EventType.MoeFc2Alpha].wait()
fc2_input = (
(swiglu_out * alpha_normalized.view([num_tokens, self.expert_size_per_partition, 1]))
.view([num_tokens, -1])

View File

@ -22,6 +22,7 @@ aux_stream_name_list = [
'MoeChunkingOverlap',
'MoeBalancer',
'MoeOutputMemset',
'MoeFc2Alpha',
]
AuxStreamType = Enum(
'AuxStreamType',