[Deepseek] Refactor Deepseek Decoder layer (#4016)

Refactor Deepseek Decoder layer

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
Co-authored-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
hlu1 2025-05-07 10:43:10 -07:00 committed by GitHub
parent bb766eca0a
commit 26a2679217
5 changed files with 165 additions and 115 deletions

@ -455,7 +455,7 @@ class Deepseekv3MoE(nn.Module):
return True
def compute_routed_output(self, hidden_states, hidden_states_fp4,
all_rank_num_tokens, min_latency_mode):
all_rank_num_tokens, cutlass_min_latency_mode):
# max-throughput
if self.use_dp and self.mapping.tp_size > 1 and not self.enable_alltoall:
max_num_token = max(all_rank_num_tokens)
@ -473,7 +473,7 @@ class Deepseekv3MoE(nn.Module):
routed_output = self.experts(hidden_states_fp4 or hidden_states,
router_logits,
min_latency_mode,
cutlass_min_latency_mode,
output_dtype=hidden_states.dtype,
all_rank_num_tokens=all_rank_num_tokens)
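Under attention data parallelism, the max-throughput branch above pads each rank's tokens before dispatching to the experts. A minimal sketch of that padding step, assuming all_rank_num_tokens carries the per-rank token counts (the surrounding gather/scatter lives elsewhere in the model):

import torch
import torch.nn.functional as F

def pad_to_max_tokens(hidden_states: torch.Tensor,
                      all_rank_num_tokens: list[int]) -> torch.Tensor:
    # Every rank pads its local activations to the largest token count in the
    # DP group so the MoE collectives that follow see uniform shapes.
    max_num_token = max(all_rank_num_tokens)
    pad_rows = max_num_token - hidden_states.shape[0]
    if pad_rows > 0:
        # (0, 0, 0, pad_rows) leaves the hidden dim alone and appends
        # pad_rows zero rows along the token dim.
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_rows))
    return hidden_states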
@ -485,9 +485,9 @@ class Deepseekv3MoE(nn.Module):
hidden_states_fp4: Optional[Fp4QuantizedTensor] = None,
all_rank_num_tokens: Optional[list[int]] = None,
final_all_reduce_params: Optional[AllReduceParams] = None,
min_latency_mode: Optional[bool] = False,
cutlass_min_latency_mode: Optional[bool] = False,
) -> torch.Tensor:
if min_latency_mode:
if cutlass_min_latency_mode:
assert not self.use_dp
def _compute_shared_output():
@ -498,10 +498,9 @@ class Deepseekv3MoE(nn.Module):
return shared_output
def _compute_routed_output():
routed_output = self.compute_routed_output(hidden_states,
hidden_states_fp4,
all_rank_num_tokens,
min_latency_mode)
routed_output = self.compute_routed_output(
hidden_states, hidden_states_fp4, all_rank_num_tokens,
cutlass_min_latency_mode)
return routed_output
shared_output, routed_output = maybe_execute_in_parallel(
@ -509,7 +508,7 @@ class Deepseekv3MoE(nn.Module):
self.event_dict[EventType.Main],
self.event_dict[EventType.MoeShared], self.aux_stream)
if min_latency_mode:
if cutlass_min_latency_mode:
return [shared_output, *routed_output]
else:
assert shared_output.size() == routed_output.size(
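maybe_execute_in_parallel overlaps the shared-expert and routed-expert branches on two CUDA streams. A rough sketch of that pattern, assuming the usual event-based hand-off (the repo's helper may also fall back to sequential execution when no aux stream is available):

import torch

def run_in_parallel(fn_main, fn_aux,
                    main_event: torch.cuda.Event,
                    aux_event: torch.cuda.Event,
                    aux_stream: torch.cuda.Stream):
    # Record where the main stream is so the aux stream can wait for the
    # shared inputs (hidden states, router logits) to be ready.
    main_event.record()
    out_main = fn_main()                 # e.g. the shared experts, current stream
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(main_event)
        out_aux = fn_aux()               # e.g. the routed experts, aux stream
        aux_event.record()
    # Do not let the main stream consume out_aux before the aux stream is done.
    torch.cuda.current_stream().wait_event(aux_event)
    return out_main, out_aux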
@ -531,6 +530,7 @@ class DeepseekV3DecoderLayer(DecoderLayer):
super().__init__()
self.model_config = model_config
config = model_config.pretrained_config
self.hidden_size = config.hidden_size
self.moe_intermediate_size = config.moe_intermediate_size
self.num_experts = config.n_routed_experts
@ -544,16 +544,17 @@ class DeepseekV3DecoderLayer(DecoderLayer):
model_config,
layer_idx=layer_idx,
aux_stream=aux_stream_dict[AuxStreamType.Attention])
self.fusion_config = EagerFusionConfig()
self.enable_attention_dp = mapping.enable_attention_dp
self.mlp_tp_size = mapping.tp_size
pp_layer_offset = mapping.pp_layers(config.num_hidden_layers)[0]
global_layer_idx = pp_layer_offset + layer_idx
enable_fusion = os.environ.get("TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED",
"0") == "0"
self.enable_fusion = enable_fusion and not self.enable_attention_dp
self.fusion_config = EagerFusionConfig()
self.enable_fusion = os.environ.get(
"TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED", "0") == "0"
self.enable_fusion &= not self.enable_attention_dp
# FIXME: incompatible with mixed quantization mode (including excluding modules from quantization)
self.is_nvfp4 = model_config.quant_config.layer_quant_mode.has_nvfp4()
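The eager-fusion switch is now resolved in one place in __init__. A small sketch of the gating, assuming only the environment variable and the attention-DP flag feed into it:

import os

def resolve_eager_fusion(enable_attention_dp: bool) -> bool:
    # Fusion is on by default; TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED=1 turns it
    # off, and attention data parallelism always turns it off.
    enabled = os.environ.get("TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED", "0") == "0"
    return enabled and not enable_attention_dp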
@ -584,8 +585,9 @@ class DeepseekV3DecoderLayer(DecoderLayer):
self.mlp_tp_size = self._compute_mlp_tp_size(
config.intermediate_size, block_size)
self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_tp and self.is_nvfp4
self.fusion_config.POST_MLP_FUSION = self.enable_fusion and self.mlp_tp_size > 1 and not has_pp
has_mlp_tp = self.mlp_tp_size > 1
self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_mlp_tp and self.is_nvfp4
self.fusion_config.POST_MLP_FUSION = self.enable_fusion and has_mlp_tp and not has_pp
self.mlp = GatedMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
@ -643,8 +645,10 @@ class DeepseekV3DecoderLayer(DecoderLayer):
)
return mlp_tp_size
def _enable_latency_mode(self, num_tokens: int):
return num_tokens <= 128 and self.fusion_config.POST_MOE_FUSION and self.is_nvfp4 and self.model_config.moe_backend == 'CUTLASS'
def _enable_min_latency_mode(self, num_tokens: int):
return (num_tokens <= 128 and self.fusion_config.POST_MOE_FUSION
and self.is_nvfp4
and self.model_config.moe_backend == 'CUTLASS')
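As a quick illustration of the renamed gate (example numbers, not taken from the repo): a 64-token decode step with NVFP4 weights, post-MoE fusion, and the CUTLASS backend takes the min-latency path; a 4096-token prefill or the TRTLLM backend does not.

def enable_min_latency_mode(num_tokens: int, post_moe_fusion: bool,
                            is_nvfp4: bool, moe_backend: str) -> bool:
    # Same condition as _enable_min_latency_mode above, as a free function.
    return (num_tokens <= 128 and post_moe_fusion and is_nvfp4
            and moe_backend == 'CUTLASS')

assert enable_min_latency_mode(64, True, True, 'CUTLASS')
assert not enable_min_latency_mode(4096, True, True, 'CUTLASS')
assert not enable_min_latency_mode(64, True, True, 'TRTLLM')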
def forward(
self,
@ -668,24 +672,78 @@ class DeepseekV3DecoderLayer(DecoderLayer):
**kwargs,
)
min_latency_mode = self._enable_latency_mode(hidden_states.size(0))
if isinstance(self.mlp, Deepseekv3MoE):
return self.forward_MoE(
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
)
else:
assert isinstance(self.mlp, GatedMLP)
return self.forward_mlp(
hidden_states=hidden_states,
residual=residual,
)
hidden_states_fp4 = None
if self.fusion_config.PRE_MOE_FUSION:
if min_latency_mode:
hidden_states, hidden_states_act, hidden_states_sf, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.
RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4,
residual=residual,
norm_weight=self.post_attention_layernorm.weight,
scale=self.mlp.experts.fc31_input_scale,
eps=self.post_attention_layernorm.variance_epsilon,
))
hidden_states_fp4 = Fp4QuantizedTensor(hidden_states_act,
hidden_states_sf)
else:
def forward_MoE(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: torch.Tensor,
) -> torch.Tensor:
def _run_MoE(hidden_states, hidden_states_fp4):
return self.mlp(
hidden_states,
hidden_states_fp4,
all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
final_all_reduce_params=AllReduceParams(
enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
or self.mapping.tp_size == 1)),
cutlass_min_latency_mode=cutlass_min_latency_mode,
)
cutlass_min_latency_mode = self._enable_min_latency_mode(
hidden_states.shape[0])
if cutlass_min_latency_mode:
assert self.fusion_config.PRE_MOE_FUSION and self.fusion_config.POST_MOE_FUSION
assert self.model_config.moe_backend == 'CUTLASS'
hidden_states, hidden_states_act, hidden_states_sf, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.
RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4,
residual=residual,
norm_weight=self.post_attention_layernorm.weight,
scale=self.mlp.experts.fc31_input_scale,
eps=self.post_attention_layernorm.variance_epsilon,
))
hidden_states_fp4 = Fp4QuantizedTensor(hidden_states_act,
hidden_states_sf)
hidden_states = _run_MoE(hidden_states, hidden_states_fp4)
shared_output = hidden_states[0]
hidden_states_activated_experts = hidden_states[1]
num_activated_experts_per_node = hidden_states[2]
experts_to_token_score = hidden_states[3]
# MoE_finalize is fused into allreduce
hidden_states, residual = self.moe_allreduce(
residual,
self.next_layer_layernorm.weight,
device_num_experts=num_activated_experts_per_node,
scale_input=experts_to_token_score,
active_experts_token_input=hidden_states_activated_experts,
token_input=shared_output,
eps=self.next_layer_layernorm.variance_epsilon,
)
else:
if self.fusion_config.PRE_MOE_FUSION:
# moe_backend can be either CUTLASS or TRTLLM here
# TODO: unify the two min-latency MoE backends by enabling quant fusion
hidden_states, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
@ -694,7 +752,36 @@ class DeepseekV3DecoderLayer(DecoderLayer):
norm_weight=self.post_attention_layernorm.weight,
eps=self.post_attention_layernorm.variance_epsilon,
))
elif self.fusion_config.PRE_MLP_FUSION:
else:
# No fusion
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = _run_MoE(hidden_states, hidden_states_fp4=None)
if self.fusion_config.POST_MOE_FUSION:
hidden_states, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
residual=residual,
norm_weight=self.next_layer_layernorm.weight,
eps=self.next_layer_layernorm.variance_epsilon,
))
else:
if self.next_layer_layernorm is not None:
hidden_states, residual = self.next_layer_layernorm(
hidden_states, residual)
return hidden_states, residual
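In the min-latency branch the MoE hands back its pieces unreduced and moe_allreduce performs the finalize. A back-of-the-envelope reference for the math that kernel fuses, with shapes as assumptions (the real op also folds in the cross-rank reduction and the next layer's residual-add + RMSNorm):

import torch

def finalize_moe_reference(shared_output: torch.Tensor,          # [tokens, hidden]
                           active_expert_outputs: torch.Tensor,  # [n_active, tokens, hidden]
                           expert_token_scores: torch.Tensor     # [n_active, tokens]
                           ) -> torch.Tensor:
    # Each active expert's output is weighted by its per-token routing score
    # and accumulated onto the shared-expert output.
    routed = (expert_token_scores.unsqueeze(-1) * active_expert_outputs).sum(dim=0)
    return shared_output + routed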
def forward_mlp(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
) -> torch.Tensor:
if self.fusion_config.PRE_MLP_FUSION:
act_fp4, act_sf, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
@ -710,54 +797,13 @@ class DeepseekV3DecoderLayer(DecoderLayer):
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if self.fusion_config.PRE_MOE_FUSION and min_latency_mode:
hidden_states = self.mlp(
hidden_states,
hidden_states_fp4,
all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
final_all_reduce_params=AllReduceParams(enable_allreduce=not (
self.fusion_config.POST_MOE_FUSION
or self.fusion_config.POST_MLP_FUSION
or self.mapping.tp_size == 1 or self.enable_attention_dp)),
min_latency_mode=min_latency_mode,
)
else:
hidden_states = self.mlp(
hidden_states,
all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
final_all_reduce_params=AllReduceParams(enable_allreduce=not (
self.fusion_config.POST_MOE_FUSION
or self.fusion_config.POST_MLP_FUSION
or self.mapping.tp_size == 1 or self.enable_attention_dp)),
min_latency_mode=min_latency_mode,
)
hidden_states = self.mlp(
hidden_states,
final_all_reduce_params=AllReduceParams(enable_allreduce=not (
self.fusion_config.POST_MLP_FUSION or self.mlp_tp_size == 1)),
)
if self.fusion_config.POST_MOE_FUSION:
if min_latency_mode:
shared_output = hidden_states[0]
hidden_states_activated_experts = hidden_states[1]
num_activated_experts_per_node = hidden_states[2]
experts_to_token_score = hidden_states[3]
hidden_states, residual = self.moe_allreduce(
residual,
self.next_layer_layernorm.weight,
device_num_experts=num_activated_experts_per_node,
scale_input=experts_to_token_score,
active_experts_token_input=hidden_states_activated_experts,
token_input=shared_output,
eps=self.next_layer_layernorm.variance_epsilon,
)
else:
hidden_states, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
residual=residual,
norm_weight=self.next_layer_layernorm.weight,
eps=self.next_layer_layernorm.variance_epsilon,
))
elif self.fusion_config.POST_MLP_FUSION:
if self.fusion_config.POST_MLP_FUSION:
hidden_states, residual = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
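Both the MLP and MoE paths lean on the RESIDUAL_RMS_NORM family of allreduce fusions. A plain-PyTorch sketch of what that fusion computes per layer (the fused kernel also performs the tensor-parallel allreduce, and the *_QUANT_NVFP4 variants additionally emit the FP4 activations and scale factors consumed above):

import torch

def residual_rms_norm(hidden_states: torch.Tensor,
                      residual: torch.Tensor,
                      norm_weight: torch.Tensor,
                      eps: float):
    # Residual add first; the updated residual is returned so the next layer
    # can reuse it, matching the (hidden_states, residual) pair in forward.
    residual = hidden_states + residual
    x = residual.to(torch.float32)
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return norm_weight * x.to(norm_weight.dtype), residual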
@ -851,13 +897,14 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
else:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
# Fully Connected
# MoE
hidden_states = self.mlp(
hidden_states,
all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
final_all_reduce_params=AllReduceParams(enable_allreduce=not (
self.fusion_config.POST_MOE_FUSION or self.mapping.tp_size == 1
or self.enable_attention_dp)),
final_all_reduce_params=AllReduceParams(
enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
or self.mapping.tp_size == 1)),
)
if self.fusion_config.POST_MOE_FUSION:

@ -295,10 +295,10 @@ class Llama4MoE(nn.Module):
self.aux_stream = aux_stream
def compute_routed_output(self, hidden_states, all_rank_num_tokens,
min_latency_mode):
cutlass_min_latency_mode):
router_logits = self.router(hidden_states)
routed_output = self.experts(hidden_states, router_logits,
min_latency_mode)
cutlass_min_latency_mode)
return routed_output
def forward(
@ -306,16 +306,16 @@ class Llama4MoE(nn.Module):
hidden_states: torch.Tensor,
all_rank_num_tokens=None,
final_all_reduce_params: Optional[AllReduceParams] = None,
min_latency_mode: Optional[bool] = False,
cutlass_min_latency_mode: Optional[bool] = False,
) -> torch.Tensor:
# Only enable multi-stream for cuda graph since switch stream has extra host overhead
# This design is mainly for low latency use case. Need to improve for max throughput use case.
fn0 = lambda: self.shared_expert(hidden_states)
fn1 = lambda: self.compute_routed_output(
hidden_states, all_rank_num_tokens, min_latency_mode)
hidden_states, all_rank_num_tokens, cutlass_min_latency_mode)
shared_output, routed_output = maybe_execute_in_parallel(
fn0, fn1, self.moe_event[0], self.moe_event[1], self.aux_stream)
if min_latency_mode:
if cutlass_min_latency_mode:
return [shared_output, *routed_output]
assert shared_output.size() == routed_output.size(
@ -414,12 +414,12 @@ class Llama4DecoderLayer(DecoderLayer):
# TODO: Remove it after we fix crash on Hopper
# major, minor = torch.cuda.get_device_capability()
# is_blackwell = (major * 10 + minor) >= 100
# min_latency_mode = hidden_states.size(
# cutlass_min_latency_mode = hidden_states.size(
# 0
# ) <= 128 and self.fusion_config.POST_MOE_FUSION and is_blackwell and self.is_quanted
# Temporarily disable min-latency mode for Llama4
min_latency_mode = False
cutlass_min_latency_mode = False
if residual is None:
residual = hidden_states
@ -456,7 +456,7 @@ class Llama4DecoderLayer(DecoderLayer):
final_all_reduce_params=AllReduceParams(enable_allreduce=not (
self.fusion_config.POST_MOE_FUSION or self.fusion_config.
POST_MLP_FUSION or self.mapping.tp_size == 1)),
min_latency_mode=min_latency_mode,
cutlass_min_latency_mode=cutlass_min_latency_mode,
)
if spec_metadata is not None:
# We save the hidden states in the spec metadata here. In _prepare_draft_tokens,
@ -467,7 +467,7 @@ class Llama4DecoderLayer(DecoderLayer):
hidden_states, residual)
if self.fusion_config.POST_MOE_FUSION or self.fusion_config.POST_MLP_FUSION:
if min_latency_mode:
if cutlass_min_latency_mode:
shared_output = hidden_states[0]
hidden_states_activated_experts = hidden_states[1]
num_activated_experts_per_node = hidden_states[2]
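Outside min-latency mode, Llama4MoE and Deepseekv3MoE end the same way: shared and routed outputs are summed per token, and the allreduce is skipped whenever a fused post-MoE/post-MLP allreduce or tp_size == 1 makes it redundant. A sketch of that tail with the collective abstracted away:

import torch

def combine_moe_outputs(shared_output: torch.Tensor,
                        routed_output: torch.Tensor,
                        all_reduce=None) -> torch.Tensor:
    assert shared_output.size() == routed_output.size()
    final_hidden_states = shared_output + routed_output
    # Passing all_reduce=None mirrors enable_allreduce=False above: the caller
    # fuses the reduction into a later op instead.
    return all_reduce(final_hidden_states) if all_reduce is not None else final_hidden_states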

@ -468,8 +468,11 @@ class MLA(nn.Module):
self.mha.update_quant_config(self.quant_config)
self.mqa.update_quant_config(self.quant_config)
has_fp8_block_scales = self.quant_config and self.quant_config.quant_mode.has_fp8_block_scales(
)
# k_b_proj_trans's dtype must be consistent with self.kv_b_proj,
# which can be modified after __init__
has_fp8_block_scales = (
self.kv_b_proj.quant_config
and self.kv_b_proj.quant_config.quant_mode.has_fp8_block_scales())
mla_weight_dtype = torch.float8_e4m3fn if has_fp8_block_scales else self.dtype
self.k_b_proj_trans = nn.Parameter(
@ -693,6 +696,7 @@ class MLA(nn.Module):
dtype=q.dtype,
device=q.device,
)
if self.k_b_proj_trans.dtype == torch.bfloat16:
# [num_heads, num_tokens, self.qk_nope_head_dim]
q_nope_t = q_nope.transpose(0, 1)
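The point of the MLA change is that k_b_proj_trans must track kv_b_proj, whose quant config can be rewritten after __init__. A sketch of the dtype decision, assuming the quant-mode API shown in the diff:

import torch

def mla_weight_dtype(kv_b_proj_quant_config, default_dtype: torch.dtype) -> torch.dtype:
    # Read the projection's own quant config rather than the module-level one,
    # so a post-__init__ change to kv_b_proj is reflected here too.
    has_fp8_block_scales = (
        kv_b_proj_quant_config is not None
        and kv_b_proj_quant_config.quant_mode.has_fp8_block_scales())
    return torch.float8_e4m3fn if has_fp8_block_scales else default_dtype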

@ -265,7 +265,7 @@ class FusedMoE(nn.Module):
equals to: dynamic quant + routing(topK, etc.) [+ fp4_allgather] + scatter + gemm1 + swiglu + gemm2 + finalizeMoeRoute [no allreduce] + reducescatter
trtllm_gen backend (moe_backend="TRTLLM"):
min-latency mode (min_latency_mode flag of forward has no effect when trtllm_gen is used):
min-latency mode (cutlass_min_latency_mode flag of forward has no effect when trtllm_gen is used):
dynamic quant + FusedMoe Op
equals to: dynamic quant + routing(topK, etc.) + scatter + gemm1 + swiglu + gemm2 + finalize MoeRoute
@ -689,7 +689,7 @@ class FusedMoE(nn.Module):
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
min_latency_mode: bool = False,
cutlass_min_latency_mode: bool = False,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens=None,
) -> torch.Tensor:
@ -766,7 +766,7 @@ class FusedMoE(nn.Module):
x_sf = reswizzle_sf(x_sf, x_row, x_col,
self.scaling_vector_size)
if self.smart_router and not min_latency_mode:
if self.smart_router and not cutlass_min_latency_mode:
ep_size = self.cluster_size
ep_rank = self.cluster_rank
expert_start = ep_rank * self.num_experts // ep_size
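For concreteness, the smart-router partition above with DeepSeek-V3's 256 routed experts over an 8-way cluster gives each rank a contiguous block of 32 experts (the end bound is an assumption mirrored from the start formula):

num_experts, ep_size = 256, 8
for ep_rank in range(ep_size):
    expert_start = ep_rank * num_experts // ep_size
    expert_end = (ep_rank + 1) * num_experts // ep_size  # assumed symmetric bound
    assert expert_end - expert_start == 32  # e.g. rank 3 owns experts 96..127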
@ -808,15 +808,15 @@ class FusedMoE(nn.Module):
cluster_size=cluster_size,
cluster_rank=cluster_rank,
use_fp8_block_scaling=use_fp8_block_scaling,
min_latency_mode=min_latency_mode,
min_latency_mode=cutlass_min_latency_mode,
)
if min_latency_mode:
if cutlass_min_latency_mode:
assert not self.reduce_results
return final_hidden_states
else:
# Custom op requires all inputs are in the same type.
# Only in min_latency_mode, the output is a list of tensors.
# Only in cutlass_min_latency_mode, the output is a list of tensors.
# Otherwise, the output should be unpacked as a single tensor.
final_hidden_states = final_hidden_states[0]
@ -830,16 +830,17 @@ class FusedMoE(nn.Module):
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
min_latency_mode: bool = False,
cutlass_min_latency_mode: bool = False,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
) -> torch.Tensor:
"""
min_latency_mode has no effect when trtllm_gen backend is enabled.
cutlass_min_latency_mode has no effect when trtllm_gen backend is enabled.
"""
if self.is_cutlass():
return self.forward_cutlass(x, router_logits, min_latency_mode,
output_dtype, all_rank_num_tokens)
return self.forward_cutlass(x, router_logits,
cutlass_min_latency_mode, output_dtype,
all_rank_num_tokens)
elif self.is_trtllm():
return self.forward_trtllmgen(x, router_logits)
else:
@ -851,7 +852,7 @@ class FusedMoE(nn.Module):
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
min_latency_mode: bool = False,
cutlass_min_latency_mode: bool = False,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
) -> torch.Tensor:
@ -866,16 +867,16 @@ class FusedMoE(nn.Module):
num_rows = x.shape[0]
num_chunks = (num_rows + max_chunk_size - 1) // max_chunk_size
if min_latency_mode:
if cutlass_min_latency_mode:
assert num_chunks == 1 and (
not self.reduce_results
), "min_latency_mode must be used with a single chunk and reduce_results must be False"
), "cutlass_min_latency_mode must be used with a single chunk and reduce_results must be False"
if num_chunks == 1:
outputs = self.forward_chunk(
x,
router_logits,
min_latency_mode,
cutlass_min_latency_mode,
output_dtype,
all_rank_num_tokens=all_rank_num_tokens)
outputs = self.reducescatter_or_allreduce(outputs)
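The chunking in forward_cutlass serves the max-throughput path; min-latency mode is restricted to a single chunk with no in-op reduction, since the decoder layer finalizes it through moe_allreduce. A stripped-down sketch of that control flow, with forward_chunk and the collective passed in as stand-ins for the methods above:

import torch

def chunked_moe_forward(x: torch.Tensor, router_logits: torch.Tensor,
                        forward_chunk, reduce_fn, max_chunk_size: int,
                        cutlass_min_latency_mode: bool = False) -> torch.Tensor:
    num_rows = x.shape[0]
    num_chunks = (num_rows + max_chunk_size - 1) // max_chunk_size
    if cutlass_min_latency_mode:
        # Single chunk, unreduced outputs; the caller runs the fused finalize.
        assert num_chunks == 1
        return forward_chunk(x, router_logits, cutlass_min_latency_mode=True)
    outputs = [
        reduce_fn(forward_chunk(x_chunk, logits_chunk))
        for x_chunk, logits_chunk in zip(x.chunk(num_chunks),
                                         router_logits.chunk(num_chunks))
    ]
    return torch.cat(outputs, dim=0)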

@ -104,13 +104,12 @@ class GatedMLP(nn.Module):
x: Union[torch.Tensor, Fp4QuantizedTensor],
all_rank_num_tokens=None,
final_all_reduce_params: Optional[AllReduceParams] = None,
min_latency_mode: Optional[bool] = False,
lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
if lora_params is not None:
return self.forward_lora(x, all_rank_num_tokens,
final_all_reduce_params, min_latency_mode,
lora_params)
final_all_reduce_params, lora_params)
if self.activation == F.silu:
h1 = self.gate_up_proj(x)
@ -146,7 +145,6 @@ class GatedMLP(nn.Module):
x: Union[torch.Tensor, Fp4QuantizedTensor],
all_rank_num_tokens=None,
final_all_reduce_params: Optional[AllReduceParams] = None,
min_latency_mode: Optional[bool] = False,
lora_params: Optional[dict] = None,
) -> torch.Tensor:
assert lora_params is not None
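For reference, the silu branch that remains in GatedMLP.forward is a standard SwiGLU block. A minimal single-GPU sketch with plain nn.Linear standing in for the repo's fused tensor-parallel projections (which half of gate_up_proj acts as the gate is an assumption here):

import torch
import torch.nn.functional as F
from torch import nn

class TinyGatedMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Gate and up projections fused into one matmul, as in gate_up_proj.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)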