Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[None][fix] Complete the last missing allreduce op in Llama3/4. (#6850)
In some circumstances, the allreduce op of the last decoder layer was missing for the Llama3 and Llama4 models: when post-MoE/MLP fusion was enabled, the fused allreduce was only emitted if a next-layer layernorm existed, so the last decoder layer skipped the reduction entirely. This change falls back to a plain allreduce in that case.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
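In sketch form, the fix changes the tail of each decoder layer's forward pass: the post-MoE/MLP fused-allreduce branch no longer requires a next-layer layernorm, and the last layer now runs a plain allreduce instead of skipping the reduction. Below is a minimal, illustrative Python rendering of that decision, with simplified stand-ins for the repo's AllReduceFusionOp and fusion config; it is not the actual TensorRT-LLM implementation.

from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class AllReduceFusionOp(Enum):
    NONE = auto()                 # plain allreduce, no fused epilogue
    RESIDUAL_RMS_NORM = auto()    # allreduce fused with residual-add + RMSNorm


@dataclass
class LayerTail:
    post_mlp_fusion: bool                   # feed-forward skipped its own allreduce
    next_layer_layernorm: Optional[object]  # None on the last decoder layer

    def pick_allreduce(self) -> Optional[AllReduceFusionOp]:
        """Decide which allreduce (if any) runs after the feed-forward module."""
        if not self.post_mlp_fusion:
            return None  # the feed-forward module already reduced
        if self.next_layer_layernorm is None:
            # Last decoder layer: nothing to fuse with, but the reduction is
            # still required -- this is the case the commit fixes.
            return AllReduceFusionOp.NONE
        # Interior layer: fuse the allreduce with the next layer's RMSNorm.
        return AllReduceFusionOp.RESIDUAL_RMS_NORM


# The previously missing case: fusion enabled, no next layernorm, still reduce.
assert LayerTail(True, None).pick_allreduce() is AllReduceFusionOp.NONE
assert LayerTail(True, object()).pick_allreduce() is AllReduceFusionOp.RESIDUAL_RMS_NORM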
This commit is contained in:
parent: b821883b25
commit: e106045fda
@@ -554,50 +554,60 @@ class Llama4DecoderLayer(DecoderLayer):
                                                       hidden_states, residual)
 
         if (self.fusion_config.POST_MOE_FUSION
-                or self.fusion_config.POST_MLP_FUSION
-            ) and self.next_layer_layernorm is not None:
-            # Get the scale for the next allreduce fusion op
-            if self.next_attn is not None and (self.is_nvfp4
-                                               or self.is_fp8_quant):
-                scale = self.next_attn.qkv_proj.input_scale
-            else:
-                # Add just the fusion op to RESIDUAL_RMS_NORM due to this is the last decoder layer
-                self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
-                scale = None
-
-            # TODO: MIN_LATENCY_MODE is hardcoded to False
-            if cutlass_min_latency_mode:
-                shared_output = hidden_states[0]
-                hidden_states_activated_experts = hidden_states[1]
-                num_activated_experts_per_node = hidden_states[2]
-                experts_to_token_score = hidden_states[3]
-
-                allreduce_output = self.moe_allreduce(
-                    residual,
-                    self.next_layer_layernorm.weight,
-                    device_num_experts=num_activated_experts_per_node,
-                    scale_input=experts_to_token_score,
-                    active_experts_token_input=hidden_states_activated_experts,
-                    token_input=shared_output,
-                    eps=self.next_layer_layernorm.variance_epsilon,
-                )
-            else:
-                allreduce_output = self.all_reduce(
+                or self.fusion_config.POST_MLP_FUSION):
+            # If there is no extra layernorm, do another pure allreduce because
+            # the allreduce in feed-forward module has been disabled.
+            if self.next_layer_layernorm is None:
+                hidden_states, residual = self.all_reduce(
                     hidden_states,
                     all_reduce_params=AllReduceParams(
-                        fusion_op=self.post_feed_forward_fusion_op,
+                        fusion_op=None,
                         residual=residual,
-                        norm_weight=self.next_layer_layernorm.weight,
-                        scale=scale,
-                        eps=self.next_layer_layernorm.variance_epsilon,
                     ))
-
-            # Unpack the allreduce output
-            if self.next_attn is not None and self.is_nvfp4:
-                act_fp4, act_sf, residual = allreduce_output
-                hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
             else:
-                hidden_states, residual = allreduce_output
+                # The next layernorm exists but it could be the last decoder layer.
+                # Adjust the scale and fusion pattern.
+                if self.next_attn is not None and (self.is_nvfp4
+                                                   or self.is_fp8_quant):
+                    scale = self.next_attn.qkv_proj.input_scale
+                else:
+                    self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+                    scale = None
+
+                # TODO: MIN_LATENCY_MODE is hardcoded to False
+                if cutlass_min_latency_mode:
+                    shared_output = hidden_states[0]
+                    hidden_states_activated_experts = hidden_states[1]
+                    num_activated_experts_per_node = hidden_states[2]
+                    experts_to_token_score = hidden_states[3]
+
+                    allreduce_output = self.moe_allreduce(
+                        residual,
+                        self.next_layer_layernorm.weight,
+                        device_num_experts=num_activated_experts_per_node,
+                        scale_input=experts_to_token_score,
+                        active_experts_token_input=
+                        hidden_states_activated_experts,
+                        token_input=shared_output,
+                        eps=self.next_layer_layernorm.variance_epsilon,
+                    )
+                else:
+                    allreduce_output = self.all_reduce(
+                        hidden_states,
+                        all_reduce_params=AllReduceParams(
+                            fusion_op=self.post_feed_forward_fusion_op,
+                            residual=residual,
+                            norm_weight=self.next_layer_layernorm.weight,
+                            scale=scale,
+                            eps=self.next_layer_layernorm.variance_epsilon,
+                        ))
+
+                # Unpack the allreduce output
+                if self.next_attn is not None and self.is_nvfp4:
+                    act_fp4, act_sf, residual = allreduce_output
+                    hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+                else:
+                    hidden_states, residual = allreduce_output
         elif self.next_layer_layernorm:
             hidden_states, residual = self.next_layer_layernorm(
                 hidden_states, residual)
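A side note on the hunk above: once a next layernorm exists, the scale and the fusion pattern are adjusted together. If the next layer's attention runs in NVFP4 or FP8, the fused allreduce can quantize its output using that attention's input scale; otherwise it downgrades to a plain RESIDUAL_RMS_NORM with no scale. A hedged sketch of that selection follows, where the quantizing enum member and the helper function are illustrative rather than the repo's API.

from enum import Enum, auto
from typing import Optional, Tuple


class FusionOp(Enum):
    RESIDUAL_RMS_NORM = auto()        # residual-add + RMSNorm, no quantization
    RESIDUAL_RMS_NORM_QUANT = auto()  # illustrative quantizing variant


def pick_scale_and_fusion(
        next_attn_input_scale: Optional[float],
        next_is_quantized: bool) -> Tuple[Optional[float], FusionOp]:
    """Choose the output scale and fusion pattern for the fused allreduce."""
    if next_attn_input_scale is not None and next_is_quantized:
        # Quantize directly into the next attention's expected input scale.
        return next_attn_input_scale, FusionOp.RESIDUAL_RMS_NORM_QUANT
    # No quantized consumer (e.g. the last decoder layer): plain fused norm.
    return None, FusionOp.RESIDUAL_RMS_NORM


assert pick_scale_and_fusion(0.5, True) == (0.5, FusionOp.RESIDUAL_RMS_NORM_QUANT)
assert pick_scale_and_fusion(None, False) == (None, FusionOp.RESIDUAL_RMS_NORM)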
@@ -710,6 +720,7 @@ class LlamaDecoderLayer(DecoderLayer):
                 scale = self.mlp.gate_up_proj.input_scale
             else:
                 scale = None
+
             all_reduce_output = self.all_reduce(
                 hidden_states,
                 all_reduce_params=AllReduceParams(
@@ -752,25 +763,40 @@ class LlamaDecoderLayer(DecoderLayer):
 
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)
-        if self.POST_MLP_FUSION and self.next_attn is not None:
-            if self.is_nvfp4 or self.is_fp8_quant:
-                scale = self.next_attn.qkv_proj.input_scale
+
+        if self.POST_MLP_FUSION:
+            # If there is no extra layernorm, do another pure allreduce.
+            if self.next_layer_layernorm is None:
+                hidden_states, residual = self.all_reduce(
+                    hidden_states,
+                    all_reduce_params=AllReduceParams(
+                        fusion_op=None,
+                        residual=residual,
+                    ))
             else:
-                scale = None
-            all_reduce_output = self.all_reduce(
-                hidden_states,
-                all_reduce_params=AllReduceParams(
-                    fusion_op=self.post_mlp_fusion_op,
-                    residual=residual,
-                    norm_weight=self.next_layer_layernorm.weight,
-                    scale=scale,
-                    eps=self.next_layer_layernorm.variance_epsilon,
-                ))
-            if self.is_nvfp4:
-                act_fp4, act_sf, residual = all_reduce_output
-                hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
-            else:
-                hidden_states, residual = all_reduce_output
+                # The next layernorm exists but it could be the last decoder layer.
+                # Adjust the scale and fusion pattern.
+                if self.next_attn is not None and (self.is_nvfp4
+                                                   or self.is_fp8_quant):
+                    scale = self.next_attn.qkv_proj.input_scale
+                else:
+                    self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+                    scale = None
+
+                all_reduce_output = self.all_reduce(
+                    hidden_states,
+                    all_reduce_params=AllReduceParams(
+                        fusion_op=self.post_mlp_fusion_op,
+                        residual=residual,
+                        norm_weight=self.next_layer_layernorm.weight,
+                        scale=scale,
+                        eps=self.next_layer_layernorm.variance_epsilon,
+                    ))
+                if self.next_attn is not None and self.is_nvfp4:
+                    act_fp4, act_sf, residual = all_reduce_output
+                    hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+                else:
+                    hidden_states, residual = all_reduce_output
         elif self.next_layer_layernorm:
             hidden_states, residual = self.next_layer_layernorm(
                 hidden_states, residual)
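Finally, both hunks end with the same unpacking convention: a quantizing fused allreduce returns a triple (FP4 activations, scale factors, residual), while the non-quantizing path returns a (hidden_states, residual) pair. A small sketch of that convention, using a stand-in Fp4QuantizedTensor rather than the real TensorRT-LLM class:

from typing import NamedTuple


class Fp4QuantizedTensor(NamedTuple):
    fp4_tensor: object      # packed FP4 activations (stand-in)
    scaling_factor: object  # per-block scale factors (stand-in)


def unpack_allreduce_output(output: tuple, quantized: bool):
    """Normalize the two fused-allreduce return shapes used in the diff."""
    if quantized:
        act_fp4, act_sf, residual = output
        return Fp4QuantizedTensor(act_fp4, act_sf), residual
    hidden_states, residual = output
    return hidden_states, residual


# Quantized path: the three outputs fold into a quantized-tensor wrapper.
hs, res = unpack_allreduce_output(("fp4", "sf", "residual"), quantized=True)
assert isinstance(hs, Fp4QuantizedTensor) and res == "residual"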