[feat] Fused finalize and allreduce for the Qwen3 MoE model (#5223)

Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
Co-authored-by: Kefeng-Duan <176893526+Kefeng-Duan@users.noreply.github.com>
Zongfei Jing 2025-06-19 08:03:58 +08:00 committed by GitHub
parent 1a7c6e7974
commit 2b23cd56ce

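Reviewer note: the sketch below is a minimal, standalone illustration of the control flow this change introduces, not the tensorrt_llm implementation. When POST_MOE_FUSION is enabled, the MoE backend is TRTLLM, the experts use nvfp4, and the token count fits under the fused kernel's limit, the decoder layer passes do_finalize=False so the expert layer returns its unreduced outputs (fc2_output, expert_scale_factor, expanded_idx_to_permuted_idx); the new MoEAllReduce op then fuses the finalize step with the tensor-parallel allreduce and the residual RMSNorm. All names in the sketch (fake_experts, MAX_TOKEN, POST_MOE_FUSION, and friends) are illustrative stand-ins for the corresponding layer attributes.

import torch

# Illustrative stand-ins -- the real MoEAllReduce / MoEAllReduceParams live in
# tensorrt_llm._torch.distributed and fuse finalize + allreduce + residual RMSNorm.
MAX_TOKEN = 128          # assumed stand-in for self.moe_allreduce.max_token
POST_MOE_FUSION = True   # stand-in for self.fusion_config.POST_MOE_FUSION
MOE_BACKEND = 'TRTLLM'   # stand-in for self.model_config.moe_backend
HAS_NVFP4 = True         # stand-in for self.mlp.experts.has_nvfp4
TOP_K = 2                # assumed top-k experts per token


def fake_experts(hidden_states: torch.Tensor, do_finalize: bool):
    """Stand-in expert layer: with do_finalize=False it returns the three tensors
    that the fused MoEAllReduce kernel consumes instead of one combined output."""
    num_tokens, hidden = hidden_states.shape
    if do_finalize:
        return hidden_states  # expert outputs already combined per token
    fc2_output = torch.randn(num_tokens * TOP_K, hidden)
    expert_scale_factor = torch.rand(num_tokens, TOP_K)
    expanded_idx_to_permuted_idx = torch.arange(num_tokens * TOP_K).view(num_tokens, TOP_K)
    return fc2_output, expert_scale_factor, expanded_idx_to_permuted_idx


hidden_states = torch.randn(32, 64)

# Same gating condition the decoder layer uses: fall back to the regular
# finalize-inside-the-MoE path unless every fusion prerequisite holds.
do_finalize = not (hidden_states.shape[0] <= MAX_TOKEN
                   and POST_MOE_FUSION
                   and MOE_BACKEND == 'TRTLLM'
                   and HAS_NVFP4)

out = fake_experts(hidden_states, do_finalize)
if do_finalize:
    print('regular path, finalized MoE output:', out.shape)
else:
    fc2_output, expert_scale_factor, expanded_idx_to_permuted_idx = out
    # In the real layer these three tensors populate MoEAllReduceParams and feed
    # self.moe_allreduce(...), which also applies the residual add and next-layer RMSNorm.
    print('fused path, MoEAllReduce finalizes:', fc2_output.shape)

The diff below wires this same decision into Qwen3MoE.forward and Qwen3MoEDecoderLayer.forward, with the fused path handled by the new MoEAllReduce op.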

@@ -8,7 +8,7 @@ from transformers import Qwen3MoeConfig
 from ..attention_backend import AttentionMetadata
 from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
-                           allgather)
+                           MoEAllReduce, MoEAllReduceParams, allgather)
 from ..model_config import ModelConfig
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
@@ -119,6 +119,7 @@ class Qwen3MoE(nn.Module):
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         all_reduce_params: Optional[AllReduceParams] = None,
+        do_finalize: Optional[bool] = True,
     ) -> torch.Tensor:
         assert hidden_states.shape[-1] == self.hidden_dim
         orig_shape = hidden_states.shape
@@ -126,6 +127,9 @@ class Qwen3MoE(nn.Module):
         use_dp_padding = False
         all_rank_num_tokens = attn_metadata.all_rank_num_tokens
 
+        if not do_finalize:
+            assert not self.enable_attention_dp
+
         if self.enable_attention_dp and self.mapping.tp_size > 1:
             # FP4 all_gather moves this bf16 allgather to after topk and fp4 quantization
             # to reduce allreduce BW
@@ -148,7 +152,12 @@ class Qwen3MoE(nn.Module):
             hidden_states,
             router_logits,
             all_rank_num_tokens=all_rank_num_tokens,
-            use_dp_padding=use_dp_padding)
+            use_dp_padding=use_dp_padding,
+            do_finalize=do_finalize,
+        )
+
+        if not do_finalize:
+            return final_hidden_states
 
         if not self.enable_attention_dp and self.mapping.tp_size > 1:
             final_hidden_states = self.allreduce(
@@ -162,6 +171,7 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
     def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
                  layer_idx: int, aux_stream: torch.cuda.Stream):
         super().__init__()
+        self.model_config = model_config
         config = model_config.pretrained_config
         self.self_attn = Qwen3Attention(
             model_config,
@@ -198,6 +208,7 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
         self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION
                                        or self.mapping.tp_size == 1
                                        or self.enable_attention_dp)
+        self.moe_allreduce = MoEAllReduce(mapping=model_config.mapping)
 
     def forward(
         self,
@@ -236,25 +247,55 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
 
+        # Note: this fusion pattern is only supported for the TRTLLM-nvfp4 backend for now
+        do_finalize = not (hidden_states.shape[0]
+                           <= self.moe_allreduce.max_token
+                           and self.fusion_config.POST_MOE_FUSION
+                           and self.model_config.moe_backend == 'TRTLLM'
+                           and self.mlp.experts.has_nvfp4)
+
         hidden_states = self.mlp(
             hidden_states,
             attn_metadata,
             all_reduce_params=AllReduceParams(
                 enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
-                                      or self.mapping.tp_size == 1)))
+                                      or self.mapping.tp_size == 1)),
+            do_finalize=do_finalize,
+        )
 
         if spec_metadata:
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)
 
         if self.fusion_config.POST_MOE_FUSION:
-            hidden_states, residual = self.allreduce(
-                hidden_states,
-                all_reduce_params=AllReduceParams(
-                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-                    residual=residual,
-                    norm_weight=self.next_layer_layernorm.weight,
-                    eps=self.next_layer_layernorm.variance_epsilon,
-                ))
+            if do_finalize:
+                hidden_states, residual = self.allreduce(
+                    hidden_states,
+                    all_reduce_params=AllReduceParams(
+                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                        residual=residual,
+                        norm_weight=self.next_layer_layernorm.weight,
+                        eps=self.next_layer_layernorm.variance_epsilon,
+                    ))
+            else:
+                assert len(
+                    hidden_states
+                ) == 3, f"hidden_states must have 3 elements, but got {len(hidden_states)}"
+
+                fc2_output = hidden_states[0]
+                expert_scale_factor = hidden_states[1]
+                expanded_idx_to_permuted_idx = hidden_states[2]
+
+                moe_all_reduce_params = MoEAllReduceParams(
+                    expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
+                    expert_scale_factor=expert_scale_factor,
+                    shared_expert_output=None,
+                    residual=residual,
+                    norm_weight=self.next_layer_layernorm.weight,
+                    eps=self.next_layer_layernorm.variance_epsilon,
+                    is_cutlass_min_latency=False,
+                )
+                hidden_states, residual = self.moe_allreduce(
+                    fc2_output, all_reduce_params=moe_all_reduce_params)
         else:
             if self.next_layer_layernorm is not None:
                 hidden_states, residual = self.next_layer_layernorm(