[None][fix] Complete the last missing allreduce op in Llama3/4. (#6850)

The allreduce op of the last decoder layer is missing in some circumstances for the Llama3 and Llama4 models: when post-MoE/MLP fusion is enabled and the layer has no next-layer layernorm (i.e., it is the last decoder layer), no allreduce is issued at all. This change adds a plain allreduce for that case.
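
In short, the old guard required both fusion and a next-layer layernorm, so the last decoder layer with fusion enabled fell through every branch and its tensor-parallel partial sums were never reduced. A minimal sketch of the before/after control flow (plain Python with hypothetical helpers, not the TensorRT-LLM source):

def old_path(fusion_enabled: bool, has_next_layernorm: bool) -> str:
    # "fusion_enabled" models POST_MOE_FUSION/POST_MLP_FUSION;
    # "has_next_layernorm" models `self.next_layer_layernorm is not None`,
    # which is False only for the last decoder layer.
    if fusion_enabled and has_next_layernorm:
        return "fused allreduce + next-layer RMSNorm"
    elif has_next_layernorm:
        return "layernorm only"
    return "nothing"  # the bug: tensor-parallel partial sums stay unreduced

def new_path(fusion_enabled: bool, has_next_layernorm: bool) -> str:
    if fusion_enabled:
        if not has_next_layernorm:
            return "plain allreduce"  # last decoder layer, fusion_op=None
        return "fused allreduce + next-layer RMSNorm"
    elif has_next_layernorm:
        return "layernorm only"
    return "nothing"

# Last decoder layer with fusion enabled:
assert old_path(True, False) == "nothing"          # allreduce silently skipped
assert new_path(True, False) == "plain allreduce"  # behavior after this commit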

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Yukun He 2025-08-15 09:07:09 +08:00 committed by Jonas Yang CN
parent b821883b25
commit e106045fda


@@ -554,50 +554,60 @@ class Llama4DecoderLayer(DecoderLayer):
                 hidden_states, residual)
 
         if (self.fusion_config.POST_MOE_FUSION
-                or self.fusion_config.POST_MLP_FUSION
-            ) and self.next_layer_layernorm is not None:
-            # Get the scale for the next allreduce fusion op
-            if self.next_attn is not None and (self.is_nvfp4
-                                               or self.is_fp8_quant):
-                scale = self.next_attn.qkv_proj.input_scale
-            else:
-                # Add just the fusion op to RESIDUAL_RMS_NORM due to this is the last decoder layer
-                self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
-                scale = None
-
-            # TODO: MIN_LATENCY_MODE is hardcoded to False
-            if cutlass_min_latency_mode:
-                shared_output = hidden_states[0]
-                hidden_states_activated_experts = hidden_states[1]
-                num_activated_experts_per_node = hidden_states[2]
-                experts_to_token_score = hidden_states[3]
-                allreduce_output = self.moe_allreduce(
-                    residual,
-                    self.next_layer_layernorm.weight,
-                    device_num_experts=num_activated_experts_per_node,
-                    scale_input=experts_to_token_score,
-                    active_experts_token_input=hidden_states_activated_experts,
-                    token_input=shared_output,
-                    eps=self.next_layer_layernorm.variance_epsilon,
-                )
-            else:
-                allreduce_output = self.all_reduce(
+                or self.fusion_config.POST_MLP_FUSION):
+            # If there is no extra layernorm, do another pure allreduce because
+            # the allreduce in feed-forward module has been disabled.
+            if self.next_layer_layernorm is None:
+                hidden_states, residual = self.all_reduce(
                     hidden_states,
                     all_reduce_params=AllReduceParams(
-                        fusion_op=self.post_feed_forward_fusion_op,
+                        fusion_op=None,
                         residual=residual,
-                        norm_weight=self.next_layer_layernorm.weight,
-                        scale=scale,
-                        eps=self.next_layer_layernorm.variance_epsilon,
                     ))
-
-            # Unpack the allreduce output
-            if self.next_attn is not None and self.is_nvfp4:
-                act_fp4, act_sf, residual = allreduce_output
-                hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
             else:
-                hidden_states, residual = allreduce_output
+                # The next layernorm exists but it could be the last decoder layer.
+                # Adjust the scale and fusion pattern.
+                if self.next_attn is not None and (self.is_nvfp4
+                                                   or self.is_fp8_quant):
+                    scale = self.next_attn.qkv_proj.input_scale
+                else:
+                    self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+                    scale = None
+
+                # TODO: MIN_LATENCY_MODE is hardcoded to False
+                if cutlass_min_latency_mode:
+                    shared_output = hidden_states[0]
+                    hidden_states_activated_experts = hidden_states[1]
+                    num_activated_experts_per_node = hidden_states[2]
+                    experts_to_token_score = hidden_states[3]
+                    allreduce_output = self.moe_allreduce(
+                        residual,
+                        self.next_layer_layernorm.weight,
+                        device_num_experts=num_activated_experts_per_node,
+                        scale_input=experts_to_token_score,
+                        active_experts_token_input=
+                        hidden_states_activated_experts,
+                        token_input=shared_output,
+                        eps=self.next_layer_layernorm.variance_epsilon,
+                    )
+                else:
+                    allreduce_output = self.all_reduce(
+                        hidden_states,
+                        all_reduce_params=AllReduceParams(
+                            fusion_op=self.post_feed_forward_fusion_op,
+                            residual=residual,
+                            norm_weight=self.next_layer_layernorm.weight,
+                            scale=scale,
+                            eps=self.next_layer_layernorm.variance_epsilon,
+                        ))
+
+                # Unpack the allreduce output
+                if self.next_attn is not None and self.is_nvfp4:
+                    act_fp4, act_sf, residual = allreduce_output
+                    hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+                else:
+                    hidden_states, residual = allreduce_output
         elif self.next_layer_layernorm:
             hidden_states, residual = self.next_layer_layernorm(
                 hidden_states, residual)
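
As context for the else branch above: once a next layernorm exists, the scale and fusion op follow the next layer's quantization mode. A simplified sketch of that selection (the enum below is an illustrative stand-in; only RESIDUAL_RMS_NORM is taken from the diff, and the real code mutates self.post_feed_forward_fusion_op in place rather than returning it):

from enum import Enum, auto

class FusionOp(Enum):
    # Illustrative stand-in for AllReduceFusionOp.
    RESIDUAL_RMS_NORM = auto()        # residual add + RMSNorm of next input
    RESIDUAL_RMS_NORM_QUANT = auto()  # hypothetical: also quantize the output

def select_fusion_params(next_attn, is_nvfp4: bool, is_fp8_quant: bool):
    """Pick (fusion_op, scale) for the post-feed-forward allreduce."""
    if next_attn is not None and (is_nvfp4 or is_fp8_quant):
        # The next attention block consumes quantized input, so the fused
        # kernel quantizes as well and needs that block's input scale.
        return FusionOp.RESIDUAL_RMS_NORM_QUANT, next_attn.qkv_proj.input_scale
    # Last decoder layer or unquantized model: plain fusion, no scale needed.
    return FusionOp.RESIDUAL_RMS_NORM, None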
@@ -710,6 +720,7 @@ class LlamaDecoderLayer(DecoderLayer):
                 scale = self.mlp.gate_up_proj.input_scale
             else:
                 scale = None
+
             all_reduce_output = self.all_reduce(
                 hidden_states,
                 all_reduce_params=AllReduceParams(
@@ -752,25 +763,40 @@ class LlamaDecoderLayer(DecoderLayer):
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)
 
-        if self.POST_MLP_FUSION and self.next_attn is not None:
-            if self.is_nvfp4 or self.is_fp8_quant:
-                scale = self.next_attn.qkv_proj.input_scale
+        if self.POST_MLP_FUSION:
+            # If there is no extra layernorm, do another pure allreduce.
+            if self.next_layer_layernorm is None:
+                hidden_states, residual = self.all_reduce(
+                    hidden_states,
+                    all_reduce_params=AllReduceParams(
+                        fusion_op=None,
+                        residual=residual,
+                    ))
             else:
-                scale = None
-            all_reduce_output = self.all_reduce(
-                hidden_states,
-                all_reduce_params=AllReduceParams(
-                    fusion_op=self.post_mlp_fusion_op,
-                    residual=residual,
-                    norm_weight=self.next_layer_layernorm.weight,
-                    scale=scale,
-                    eps=self.next_layer_layernorm.variance_epsilon,
-                ))
-            if self.is_nvfp4:
-                act_fp4, act_sf, residual = all_reduce_output
-                hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
-            else:
-                hidden_states, residual = all_reduce_output
+                # The next layernorm exists but it could be the last decoder layer.
+                # Adjust the scale and fusion pattern.
+                if self.next_attn is not None and (self.is_nvfp4
+                                                   or self.is_fp8_quant):
+                    scale = self.next_attn.qkv_proj.input_scale
+                else:
+                    self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+                    scale = None
+
+                all_reduce_output = self.all_reduce(
+                    hidden_states,
+                    all_reduce_params=AllReduceParams(
+                        fusion_op=self.post_mlp_fusion_op,
+                        residual=residual,
+                        norm_weight=self.next_layer_layernorm.weight,
+                        scale=scale,
+                        eps=self.next_layer_layernorm.variance_epsilon,
+                    ))
+
+                if self.next_attn is not None and self.is_nvfp4:
+                    act_fp4, act_sf, residual = all_reduce_output
+                    hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+                else:
+                    hidden_states, residual = all_reduce_output
         elif self.next_layer_layernorm:
             hidden_states, residual = self.next_layer_layernorm(
                 hidden_states, residual)
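
The LlamaDecoderLayer (Llama3) change mirrors the Llama4 one, minus the MoE min-latency path. One way to sanity-check either fix under tensor parallelism (an illustrative test sketch; assumes an initialized torch.distributed process group spanning the TP ranks):

import torch
import torch.distributed as dist

def assert_ranks_agree(hidden_states: torch.Tensor) -> None:
    # After the last decoder layer, every TP rank must hold identical hidden
    # states; before this fix, the fused path could leave per-rank partial
    # sums because the final allreduce was skipped.
    reference = hidden_states.clone()
    dist.broadcast(reference, src=0)  # take rank 0's copy as the reference
    torch.testing.assert_close(hidden_states, reference)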