Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[feat] Fusion finalize and allreduce for qwenmoe model (#5223)
Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
Co-authored-by: Kefeng-Duan <176893526+Kefeng-Duan@users.noreply.github.com>
parent 1a7c6e7974
commit 2b23cd56ce

@@ -8,7 +8,7 @@ from transformers import Qwen3MoeConfig
 from ..attention_backend import AttentionMetadata
 from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
-                           allgather)
+                           MoEAllReduce, MoEAllReduceParams, allgather)
 from ..model_config import ModelConfig
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
@@ -119,6 +119,7 @@ class Qwen3MoE(nn.Module):
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         all_reduce_params: Optional[AllReduceParams] = None,
+        do_finalize: Optional[bool] = True,
     ) -> torch.Tensor:
         assert hidden_states.shape[-1] == self.hidden_dim
         orig_shape = hidden_states.shape
@@ -126,6 +127,9 @@ class Qwen3MoE(nn.Module):
         use_dp_padding = False
         all_rank_num_tokens = attn_metadata.all_rank_num_tokens
 
+        if not do_finalize:
+            assert not self.enable_attention_dp
+
         if self.enable_attention_dp and self.mapping.tp_size > 1:
             # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
             # to reduce allreduce BW
@@ -148,7 +152,12 @@ class Qwen3MoE(nn.Module):
             hidden_states,
             router_logits,
             all_rank_num_tokens=all_rank_num_tokens,
-            use_dp_padding=use_dp_padding)
+            use_dp_padding=use_dp_padding,
+            do_finalize=do_finalize,
+        )
+
+        if not do_finalize:
+            return final_hidden_states
 
         if not self.enable_attention_dp and self.mapping.tp_size > 1:
             final_hidden_states = self.allreduce(
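
Note on the Qwen3MoE changes above: when called with do_finalize=False, the MoE forward now skips the final weighted recombination and hands back the raw pieces for a later fused kernel to finish. The snippet below is a toy sketch of that contract, not TensorRT-LLM code; toy_moe_forward, the top_k value, and the random stand-in tensors are assumptions made for the example, while the three returned tensors carry the names this diff uses.

import torch

def toy_moe_forward(hidden_states: torch.Tensor, do_finalize: bool = True):
    # Stand-ins for what a real MoE backend would produce after routing + expert FC2.
    num_tokens, hidden = hidden_states.shape
    top_k = 2
    fc2_output = torch.randn(num_tokens * top_k, hidden)        # per (token, expert-slot) outputs
    expert_scale_factor = torch.rand(num_tokens, top_k)         # top-k router weights
    expanded_idx_to_permuted_idx = torch.randperm(num_tokens * top_k)

    if not do_finalize:
        # Deferred path: return the raw pieces so a fused finalize+allreduce
        # kernel can combine them later.
        return [fc2_output, expert_scale_factor, expanded_idx_to_permuted_idx]

    # Default path: finalize here by gathering each token's expert outputs
    # back into token order and summing them with the router weights.
    gathered = fc2_output[expanded_idx_to_permuted_idx].view(num_tokens, top_k, hidden)
    return (gathered * expert_scale_factor.unsqueeze(-1)).sum(dim=1)

print(toy_moe_forward(torch.randn(4, 8)).shape)                    # finalized: torch.Size([4, 8])
print(len(toy_moe_forward(torch.randn(4, 8), do_finalize=False)))  # deferred: 3 tensors
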
@@ -162,6 +171,7 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
     def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
                  layer_idx: int, aux_stream: torch.cuda.Stream):
         super().__init__()
+        self.model_config = model_config
         config = model_config.pretrained_config
         self.self_attn = Qwen3Attention(
             model_config,
@@ -198,6 +208,7 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
         self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION
                                        or self.mapping.tp_size == 1
                                        or self.enable_attention_dp)
+        self.moe_allreduce = MoEAllReduce(mapping=model_config.mapping)
 
     def forward(
         self,
@@ -236,25 +247,55 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
             hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual)
 
+        # Note: this fusion pattern is only supported for TRTLLM-nvfp4 backend now
+        do_finalize = not (hidden_states.shape[0]
+                           <= self.moe_allreduce.max_token
+                           and self.fusion_config.POST_MOE_FUSION
+                           and self.model_config.moe_backend == 'TRTLLM'
+                           and self.mlp.experts.has_nvfp4)
+
         hidden_states = self.mlp(
             hidden_states,
             attn_metadata,
             all_reduce_params=AllReduceParams(
                 enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
-                                      or self.mapping.tp_size == 1)))
+                                      or self.mapping.tp_size == 1)),
+            do_finalize=do_finalize,
+        )
 
         if spec_metadata:
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)
         if self.fusion_config.POST_MOE_FUSION:
-            hidden_states, residual = self.allreduce(
-                hidden_states,
-                all_reduce_params=AllReduceParams(
-                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+            if do_finalize:
+                hidden_states, residual = self.allreduce(
+                    hidden_states,
+                    all_reduce_params=AllReduceParams(
+                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                        residual=residual,
+                        norm_weight=self.next_layer_layernorm.weight,
+                        eps=self.next_layer_layernorm.variance_epsilon,
+                    ))
+            else:
+                assert len(
+                    hidden_states
+                ) == 3, f"hidden_states must have 3 elements, but got {len(hidden_states)}"
+
+                fc2_output = hidden_states[0]
+                expert_scale_factor = hidden_states[1]
+                expanded_idx_to_permuted_idx = hidden_states[2]
+
+                moe_all_reduce_params = MoEAllReduceParams(
+                    expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
+                    expert_scale_factor=expert_scale_factor,
+                    shared_expert_output=None,
                     residual=residual,
                     norm_weight=self.next_layer_layernorm.weight,
                     eps=self.next_layer_layernorm.variance_epsilon,
-                ))
+                    is_cutlass_min_latency=False,
+                )
+                hidden_states, residual = self.moe_allreduce(
+                    fc2_output, all_reduce_params=moe_all_reduce_params)
         else:
             if self.next_layer_layernorm is not None:
                 hidden_states, residual = self.next_layer_layernorm(
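
For reference, the sketch below spells out, on a single rank, what the fused MoEAllReduce call above folds into one launch: finalize the MoE combine from the three deferred tensors, add the residual, and apply the next layer's RMSNorm (the cross-rank all-reduce itself is omitted). This is an illustrative reconstruction, not the kernel implementation; unfused_reference, rms_norm, the eps default, and the smoke-test shapes are assumptions, and returning both the normalized output and the pre-norm sum mirrors the usual RESIDUAL_RMS_NORM convention of updating the residual in the same step.

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Plain RMSNorm: scale by the reciprocal root-mean-square, then by the weight.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight

def unfused_reference(fc2_output, expert_scale_factor, expanded_idx_to_permuted_idx,
                      residual, norm_weight, eps=1e-6):
    num_tokens, top_k = expert_scale_factor.shape
    hidden = fc2_output.shape[-1]
    # 1) Finalize: gather each token's top-k expert outputs back to token order
    #    and combine them with the router scale factors.
    gathered = fc2_output[expanded_idx_to_permuted_idx].view(num_tokens, top_k, hidden)
    combined = (gathered * expert_scale_factor.unsqueeze(-1)).sum(dim=1)
    # 2) In the real path, an all-reduce over tensor-parallel ranks happens here.
    # 3) Residual add + RMSNorm; the pre-norm sum becomes the next residual.
    new_residual = combined + residual
    return rms_norm(new_residual, norm_weight, eps), new_residual

# Tiny smoke test with random stand-in tensors.
T, K, H = 4, 2, 8
out, res = unfused_reference(torch.randn(T * K, H), torch.rand(T, K),
                             torch.randperm(T * K), residual=torch.randn(T, H),
                             norm_weight=torch.ones(H))
print(out.shape, res.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
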