From 26a2679217f72a0b29e18390fc71856e9e7ca701 Mon Sep 17 00:00:00 2001 From: hlu1 <14827759+hlu1@users.noreply.github.com> Date: Wed, 7 May 2025 10:43:10 -0700 Subject: [PATCH] [Deepseek] Refactor Deepseek Decoder layer (#4016) Refactor Deepseek Decoder layer Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com@users.noreply.github.com> Co-authored-by: Hao Lu <14827759+hlu1@users.noreply.github.com@users.noreply.github.com> --- .../_torch/models/modeling_deepseekv3.py | 219 +++++++++++------- tensorrt_llm/_torch/models/modeling_llama.py | 18 +- tensorrt_llm/_torch/modules/attention.py | 8 +- tensorrt_llm/_torch/modules/fused_moe.py | 29 +-- tensorrt_llm/_torch/modules/gated_mlp.py | 6 +- 5 files changed, 165 insertions(+), 115 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 69c51bb1d3..e64f0990d4 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -455,7 +455,7 @@ class Deepseekv3MoE(nn.Module): return True def compute_routed_output(self, hidden_states, hidden_states_fp4, - all_rank_num_tokens, min_latency_mode): + all_rank_num_tokens, cutlass_min_latency_mode): # max-throughput if self.use_dp and self.mapping.tp_size > 1 and not self.enable_alltoall: max_num_token = max(all_rank_num_tokens) @@ -473,7 +473,7 @@ class Deepseekv3MoE(nn.Module): routed_output = self.experts(hidden_states_fp4 or hidden_states, router_logits, - min_latency_mode, + cutlass_min_latency_mode, output_dtype=hidden_states.dtype, all_rank_num_tokens=all_rank_num_tokens) @@ -485,9 +485,9 @@ class Deepseekv3MoE(nn.Module): hidden_states_fp4: Optional[Fp4QuantizedTensor] = None, all_rank_num_tokens: Optional[list[int]] = None, final_all_reduce_params: Optional[AllReduceParams] = None, - min_latency_mode: Optional[bool] = False, + cutlass_min_latency_mode: Optional[bool] = False, ) -> torch.Tensor: - if min_latency_mode: + if cutlass_min_latency_mode: assert not self.use_dp def _compute_shared_output(): @@ -498,10 +498,9 @@ class Deepseekv3MoE(nn.Module): return shared_output def _compute_routed_output(): - routed_output = self.compute_routed_output(hidden_states, - hidden_states_fp4, - all_rank_num_tokens, - min_latency_mode) + routed_output = self.compute_routed_output( + hidden_states, hidden_states_fp4, all_rank_num_tokens, + cutlass_min_latency_mode) return routed_output shared_output, routed_output = maybe_execute_in_parallel( @@ -509,7 +508,7 @@ class Deepseekv3MoE(nn.Module): self.event_dict[EventType.Main], self.event_dict[EventType.MoeShared], self.aux_stream) - if min_latency_mode: + if cutlass_min_latency_mode: return [shared_output, *routed_output] else: assert shared_output.size() == routed_output.size( @@ -531,6 +530,7 @@ class DeepseekV3DecoderLayer(DecoderLayer): super().__init__() self.model_config = model_config config = model_config.pretrained_config + self.hidden_size = config.hidden_size self.moe_intermediate_size = config.moe_intermediate_size self.num_experts = config.n_routed_experts @@ -544,16 +544,17 @@ class DeepseekV3DecoderLayer(DecoderLayer): model_config, layer_idx=layer_idx, aux_stream=aux_stream_dict[AuxStreamType.Attention]) - self.fusion_config = EagerFusionConfig() self.enable_attention_dp = mapping.enable_attention_dp + self.mlp_tp_size = mapping.tp_size pp_layer_offset = mapping.pp_layers(config.num_hidden_layers)[0] global_layer_idx = pp_layer_offset + layer_idx - enable_fusion = os.environ.get("TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED", - "0") == "0" - self.enable_fusion = enable_fusion and not self.enable_attention_dp + self.fusion_config = EagerFusionConfig() + self.enable_fusion = os.environ.get( + "TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED", "0") == "0" + self.enable_fusion &= not self.enable_attention_dp # FIXME: incompatible with mixed quantization mode (including excluding modules from quantization) self.is_nvfp4 = model_config.quant_config.layer_quant_mode.has_nvfp4() @@ -584,8 +585,9 @@ class DeepseekV3DecoderLayer(DecoderLayer): self.mlp_tp_size = self._compute_mlp_tp_size( config.intermediate_size, block_size) - self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_tp and self.is_nvfp4 - self.fusion_config.POST_MLP_FUSION = self.enable_fusion and self.mlp_tp_size > 1 and not has_pp + has_mlp_tp = self.mlp_tp_size > 1 + self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_mlp_tp and self.is_nvfp4 + self.fusion_config.POST_MLP_FUSION = self.enable_fusion and has_mlp_tp and not has_pp self.mlp = GatedMLP(hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, @@ -643,8 +645,10 @@ class DeepseekV3DecoderLayer(DecoderLayer): ) return mlp_tp_size - def _enable_latency_mode(self, num_tokens: int): - return num_tokens <= 128 and self.fusion_config.POST_MOE_FUSION and self.is_nvfp4 and self.model_config.moe_backend == 'CUTLASS' + def _enable_min_latency_mode(self, num_tokens: int): + return (num_tokens <= 128 and self.fusion_config.POST_MOE_FUSION + and self.is_nvfp4 + and self.model_config.moe_backend == 'CUTLASS') def forward( self, @@ -668,24 +672,78 @@ class DeepseekV3DecoderLayer(DecoderLayer): **kwargs, ) - min_latency_mode = self._enable_latency_mode(hidden_states.size(0)) + if isinstance(self.mlp, Deepseekv3MoE): + return self.forward_MoE( + hidden_states=hidden_states, + attn_metadata=attn_metadata, + residual=residual, + ) + else: + assert isinstance(self.mlp, GatedMLP) + return self.forward_mlp( + hidden_states=hidden_states, + residual=residual, + ) - hidden_states_fp4 = None - if self.fusion_config.PRE_MOE_FUSION: - if min_latency_mode: - hidden_states, hidden_states_act, hidden_states_sf, residual = self.allreduce( - hidden_states, - all_reduce_params=AllReduceParams( - fusion_op=AllReduceFusionOp. - RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4, - residual=residual, - norm_weight=self.post_attention_layernorm.weight, - scale=self.mlp.experts.fc31_input_scale, - eps=self.post_attention_layernorm.variance_epsilon, - )) - hidden_states_fp4 = Fp4QuantizedTensor(hidden_states_act, - hidden_states_sf) - else: + def forward_MoE( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: torch.Tensor, + ) -> torch.Tensor: + + def _run_MoE(hidden_states, hidden_states_fp4): + return self.mlp( + hidden_states, + hidden_states_fp4, + all_rank_num_tokens=attn_metadata.all_rank_num_tokens, + final_all_reduce_params=AllReduceParams( + enable_allreduce=not (self.fusion_config.POST_MOE_FUSION + or self.mapping.tp_size == 1)), + cutlass_min_latency_mode=cutlass_min_latency_mode, + ) + + cutlass_min_latency_mode = self._enable_min_latency_mode( + hidden_states.shape[0]) + + if cutlass_min_latency_mode: + assert self.fusion_config.PRE_MOE_FUSION and self.fusion_config.POST_MOE_FUSION + assert self.model_config.moe_backend == 'CUTLASS' + + hidden_states, hidden_states_act, hidden_states_sf, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp. + RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4, + residual=residual, + norm_weight=self.post_attention_layernorm.weight, + scale=self.mlp.experts.fc31_input_scale, + eps=self.post_attention_layernorm.variance_epsilon, + )) + hidden_states_fp4 = Fp4QuantizedTensor(hidden_states_act, + hidden_states_sf) + + hidden_states = _run_MoE(hidden_states, hidden_states_fp4) + + shared_output = hidden_states[0] + hidden_states_activated_experts = hidden_states[1] + num_activated_experts_per_node = hidden_states[2] + experts_to_token_score = hidden_states[3] + + # MoE_finalize is fused into allreduce + hidden_states, residual = self.moe_allreduce( + residual, + self.next_layer_layernorm.weight, + device_num_experts=num_activated_experts_per_node, + scale_input=experts_to_token_score, + active_experts_token_input=hidden_states_activated_experts, + token_input=shared_output, + eps=self.next_layer_layernorm.variance_epsilon, + ) + else: + if self.fusion_config.PRE_MOE_FUSION: + # moe_backend can be either CUTLASS or TRTLLM here + # TODO: unify the two min-latency MoE backends by enabling quant fusion hidden_states, residual = self.allreduce( hidden_states, all_reduce_params=AllReduceParams( @@ -694,7 +752,36 @@ class DeepseekV3DecoderLayer(DecoderLayer): norm_weight=self.post_attention_layernorm.weight, eps=self.post_attention_layernorm.variance_epsilon, )) - elif self.fusion_config.PRE_MLP_FUSION: + else: + # No fusion + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + hidden_states = _run_MoE(hidden_states, hidden_states_fp4=None) + + if self.fusion_config.POST_MOE_FUSION: + hidden_states, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.next_layer_layernorm.weight, + eps=self.next_layer_layernorm.variance_epsilon, + )) + else: + if self.next_layer_layernorm is not None: + hidden_states, residual = self.next_layer_layernorm( + hidden_states, residual) + + return hidden_states, residual + + def forward_mlp( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + ) -> torch.Tensor: + + if self.fusion_config.PRE_MLP_FUSION: act_fp4, act_sf, residual = self.allreduce( hidden_states, all_reduce_params=AllReduceParams( @@ -710,54 +797,13 @@ class DeepseekV3DecoderLayer(DecoderLayer): hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) - if self.fusion_config.PRE_MOE_FUSION and min_latency_mode: - hidden_states = self.mlp( - hidden_states, - hidden_states_fp4, - all_rank_num_tokens=attn_metadata.all_rank_num_tokens, - final_all_reduce_params=AllReduceParams(enable_allreduce=not ( - self.fusion_config.POST_MOE_FUSION - or self.fusion_config.POST_MLP_FUSION - or self.mapping.tp_size == 1 or self.enable_attention_dp)), - min_latency_mode=min_latency_mode, - ) - else: - hidden_states = self.mlp( - hidden_states, - all_rank_num_tokens=attn_metadata.all_rank_num_tokens, - final_all_reduce_params=AllReduceParams(enable_allreduce=not ( - self.fusion_config.POST_MOE_FUSION - or self.fusion_config.POST_MLP_FUSION - or self.mapping.tp_size == 1 or self.enable_attention_dp)), - min_latency_mode=min_latency_mode, - ) + hidden_states = self.mlp( + hidden_states, + final_all_reduce_params=AllReduceParams(enable_allreduce=not ( + self.fusion_config.POST_MLP_FUSION or self.mlp_tp_size == 1)), + ) - if self.fusion_config.POST_MOE_FUSION: - if min_latency_mode: - shared_output = hidden_states[0] - hidden_states_activated_experts = hidden_states[1] - num_activated_experts_per_node = hidden_states[2] - experts_to_token_score = hidden_states[3] - - hidden_states, residual = self.moe_allreduce( - residual, - self.next_layer_layernorm.weight, - device_num_experts=num_activated_experts_per_node, - scale_input=experts_to_token_score, - active_experts_token_input=hidden_states_activated_experts, - token_input=shared_output, - eps=self.next_layer_layernorm.variance_epsilon, - ) - else: - hidden_states, residual = self.allreduce( - hidden_states, - all_reduce_params=AllReduceParams( - fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, - residual=residual, - norm_weight=self.next_layer_layernorm.weight, - eps=self.next_layer_layernorm.variance_epsilon, - )) - elif self.fusion_config.POST_MLP_FUSION: + if self.fusion_config.POST_MLP_FUSION: hidden_states, residual = self.allreduce( hidden_states, all_reduce_params=AllReduceParams( @@ -851,13 +897,14 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer): else: hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) - # Fully Connected + + # MoE hidden_states = self.mlp( hidden_states, all_rank_num_tokens=spec_metadata.all_rank_num_tokens, - final_all_reduce_params=AllReduceParams(enable_allreduce=not ( - self.fusion_config.POST_MOE_FUSION or self.mapping.tp_size == 1 - or self.enable_attention_dp)), + final_all_reduce_params=AllReduceParams( + enable_allreduce=not (self.fusion_config.POST_MOE_FUSION + or self.mapping.tp_size == 1)), ) if self.fusion_config.POST_MOE_FUSION: diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 0047fd4666..9c0d30e9b0 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -295,10 +295,10 @@ class Llama4MoE(nn.Module): self.aux_stream = aux_stream def compute_routed_output(self, hidden_states, all_rank_num_tokens, - min_latency_mode): + cutlass_min_latency_mode): router_logits = self.router(hidden_states) routed_output = self.experts(hidden_states, router_logits, - min_latency_mode) + cutlass_min_latency_mode) return routed_output def forward( @@ -306,16 +306,16 @@ class Llama4MoE(nn.Module): hidden_states: torch.Tensor, all_rank_num_tokens=None, final_all_reduce_params: Optional[AllReduceParams] = None, - min_latency_mode: Optional[bool] = False, + cutlass_min_latency_mode: Optional[bool] = False, ) -> torch.Tensor: # Only enable multi-stream for cuda graph since switch stream has extra host overhead # This design is mainly for low latency use case. Need to improve for max throughput use case. fn0 = lambda: self.shared_expert(hidden_states) fn1 = lambda: self.compute_routed_output( - hidden_states, all_rank_num_tokens, min_latency_mode) + hidden_states, all_rank_num_tokens, cutlass_min_latency_mode) shared_output, routed_output = maybe_execute_in_parallel( fn0, fn1, self.moe_event[0], self.moe_event[1], self.aux_stream) - if min_latency_mode: + if cutlass_min_latency_mode: return [shared_output, *routed_output] assert shared_output.size() == routed_output.size( @@ -414,12 +414,12 @@ class Llama4DecoderLayer(DecoderLayer): # TODO: Remove it after we fix crash on Hopper # major, minor = torch.cuda.get_device_capability() # is_blackwell = (major * 10 + minor) >= 100 - # min_latency_mode = hidden_states.size( + # cutlass_min_latency_mode = hidden_states.size( # 0 # ) <= 128 and self.fusion_config.POST_MOE_FUSION and is_blackwell and self.is_quanted # Temporarily disable min-latency mode for Llama4 - min_latency_mode = False + cutlass_min_latency_mode = False if residual is None: residual = hidden_states @@ -456,7 +456,7 @@ class Llama4DecoderLayer(DecoderLayer): final_all_reduce_params=AllReduceParams(enable_allreduce=not ( self.fusion_config.POST_MOE_FUSION or self.fusion_config. POST_MLP_FUSION or self.mapping.tp_size == 1)), - min_latency_mode=min_latency_mode, + cutlass_min_latency_mode=cutlass_min_latency_mode, ) if spec_metadata is not None: # We save the hidden states in the spec metadata here. In _prepare_draft_tokens, @@ -467,7 +467,7 @@ class Llama4DecoderLayer(DecoderLayer): hidden_states, residual) if self.fusion_config.POST_MOE_FUSION or self.fusion_config.POST_MLP_FUSION: - if min_latency_mode: + if cutlass_min_latency_mode: shared_output = hidden_states[0] hidden_states_activated_experts = hidden_states[1] num_activated_experts_per_node = hidden_states[2] diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 17cd9701a0..f16efe8caa 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -468,8 +468,11 @@ class MLA(nn.Module): self.mha.update_quant_config(self.quant_config) self.mqa.update_quant_config(self.quant_config) - has_fp8_block_scales = self.quant_config and self.quant_config.quant_mode.has_fp8_block_scales( - ) + # k_b_proj_trans's dtype must be consistent with self.kv_b_proj, + # which can be modified after __init__ + has_fp8_block_scales = ( + self.kv_b_proj.quant_config + and self.kv_b_proj.quant_config.quant_mode.has_fp8_block_scales()) mla_weight_dtype = torch.float8_e4m3fn if has_fp8_block_scales else self.dtype self.k_b_proj_trans = nn.Parameter( @@ -693,6 +696,7 @@ class MLA(nn.Module): dtype=q.dtype, device=q.device, ) + if self.k_b_proj_trans.dtype == torch.bfloat16: # [num_heads, num_tokens, self.qk_nope_head_dim] q_nope_t = q_nope.transpose(0, 1) diff --git a/tensorrt_llm/_torch/modules/fused_moe.py b/tensorrt_llm/_torch/modules/fused_moe.py index f2728fbf98..f316c66e94 100755 --- a/tensorrt_llm/_torch/modules/fused_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe.py @@ -265,7 +265,7 @@ class FusedMoE(nn.Module): equals to: dynamic quant + routing(topK, etc.) [+ fp4_allgather] + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute [no allreduce] + reducescatter trtllm_gen backend (moe_backend="TRTLLM"): - min-latency mode (min_latency_mode flag of forward has no effect when trtllm_gen is used): + min-latency mode (cutlass_min_latency_mode flag of forward has no effect when trtllm_gen is used): dynamic quant + FusedMoe Op equals to: dynamic quant + routing(topK, etc.) + scatter + gemm1 + swiglu + gemm2 + finalize MoeRoute @@ -689,7 +689,7 @@ class FusedMoE(nn.Module): self, x: Union[torch.Tensor, Fp4QuantizedTensor], router_logits: torch.Tensor, - min_latency_mode: bool = False, + cutlass_min_latency_mode: bool = False, output_dtype: Optional[torch.dtype] = None, all_rank_num_tokens=None, ) -> torch.Tensor: @@ -766,7 +766,7 @@ class FusedMoE(nn.Module): x_sf = reswizzle_sf(x_sf, x_row, x_col, self.scaling_vector_size) - if self.smart_router and not min_latency_mode: + if self.smart_router and not cutlass_min_latency_mode: ep_size = self.cluster_size ep_rank = self.cluster_rank expert_start = ep_rank * self.num_experts // ep_size @@ -808,15 +808,15 @@ class FusedMoE(nn.Module): cluster_size=cluster_size, cluster_rank=cluster_rank, use_fp8_block_scaling=use_fp8_block_scaling, - min_latency_mode=min_latency_mode, + min_latency_mode=cutlass_min_latency_mode, ) - if min_latency_mode: + if cutlass_min_latency_mode: assert not self.reduce_results return final_hidden_states else: # Custom op requires all inputs are in the same type. - # Only in min_latency_mode, the output is a list of tensors. + # Only in cutlass_min_latency_mode, the output is a list of tensors. # Otherwise, the output should be unpacked as a single tensor. final_hidden_states = final_hidden_states[0] @@ -830,16 +830,17 @@ class FusedMoE(nn.Module): self, x: Union[torch.Tensor, Fp4QuantizedTensor], router_logits: torch.Tensor, - min_latency_mode: bool = False, + cutlass_min_latency_mode: bool = False, output_dtype: Optional[torch.dtype] = None, all_rank_num_tokens: Optional[List[int]] = None, ) -> torch.Tensor: """ - min_latency_mode has no effect when trtllm_gen backend is enabled. + cutlass_min_latency_mode has no effect when trtllm_gen backend is enabled. """ if self.is_cutlass(): - return self.forward_cutlass(x, router_logits, min_latency_mode, - output_dtype, all_rank_num_tokens) + return self.forward_cutlass(x, router_logits, + cutlass_min_latency_mode, output_dtype, + all_rank_num_tokens) elif self.is_trtllm(): return self.forward_trtllmgen(x, router_logits) else: @@ -851,7 +852,7 @@ class FusedMoE(nn.Module): self, x: Union[torch.Tensor, Fp4QuantizedTensor], router_logits: torch.Tensor, - min_latency_mode: bool = False, + cutlass_min_latency_mode: bool = False, output_dtype: Optional[torch.dtype] = None, all_rank_num_tokens: Optional[List[int]] = None, ) -> torch.Tensor: @@ -866,16 +867,16 @@ class FusedMoE(nn.Module): num_rows = x.shape[0] num_chunks = (num_rows + max_chunk_size - 1) // max_chunk_size - if min_latency_mode: + if cutlass_min_latency_mode: assert num_chunks == 1 and ( not self.reduce_results - ), "min_latency_mode must be used with a single chunk and reduce_results must be False" + ), "cutlass_min_latency_mode must be used with a single chunk and reduce_results must be False" if num_chunks == 1: outputs = self.forward_chunk( x, router_logits, - min_latency_mode, + cutlass_min_latency_mode, output_dtype, all_rank_num_tokens=all_rank_num_tokens) outputs = self.reducescatter_or_allreduce(outputs) diff --git a/tensorrt_llm/_torch/modules/gated_mlp.py b/tensorrt_llm/_torch/modules/gated_mlp.py index 748fd6fe55..d1775c4502 100644 --- a/tensorrt_llm/_torch/modules/gated_mlp.py +++ b/tensorrt_llm/_torch/modules/gated_mlp.py @@ -104,13 +104,12 @@ class GatedMLP(nn.Module): x: Union[torch.Tensor, Fp4QuantizedTensor], all_rank_num_tokens=None, final_all_reduce_params: Optional[AllReduceParams] = None, - min_latency_mode: Optional[bool] = False, lora_params: Optional[dict] = None, + **kwargs, ) -> torch.Tensor: if lora_params is not None: return self.forward_lora(x, all_rank_num_tokens, - final_all_reduce_params, min_latency_mode, - lora_params) + final_all_reduce_params, lora_params) if self.activation == F.silu: h1 = self.gate_up_proj(x) @@ -146,7 +145,6 @@ class GatedMLP(nn.Module): x: Union[torch.Tensor, Fp4QuantizedTensor], all_rank_num_tokens=None, final_all_reduce_params: Optional[AllReduceParams] = None, - min_latency_mode: Optional[bool] = False, lora_params: Optional[dict] = None, ) -> torch.Tensor: assert lora_params is not None