Fix Pipeline Parallelism in Llama4 (#4106)

Signed-off-by: Shobhit Verma <shobhitv@nvidia.com>

@@ -440,7 +440,9 @@ class Llama4DecoderLayer(DecoderLayer):
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)
-        if self.fusion_config.POST_MOE_FUSION or self.fusion_config.POST_MLP_FUSION:
+        if (self.fusion_config.POST_MOE_FUSION
+                or self.fusion_config.POST_MLP_FUSION
+            ) and self.next_layer_layernorm is not None:
             if cutlass_min_latency_mode:
                 shared_output = hidden_states[0]
                 hidden_states_activated_experts = hidden_states[1]
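
For context, the added `next_layer_layernorm is not None` guard is what makes this hunk a pipeline-parallelism fix: with the layers split across ranks, the last decoder layer owned by a stage has no local successor whose input layernorm could be fused with the residual add, so its `next_layer_layernorm` presumably stays `None` and the fused epilogue must be skipped. Below is a minimal, self-contained sketch of that pattern under this assumption; `ToyDecoderLayer` and the wiring in `__main__` are illustrative stand-ins, not the TensorRT-LLM implementation.

import torch
from torch import nn


class ToyDecoderLayer(nn.Module):
    """Illustrative stand-in for Llama4DecoderLayer (not the real class)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.mlp = nn.Linear(hidden_size, hidden_size)
        # Wired up by the model after all *local* layers are built; stays
        # None for the last layer owned by a pipeline-parallel rank.
        self.next_layer_layernorm = None

    def forward(self, hidden_states, residual):
        hidden_states = self.mlp(hidden_states)
        # Same guard as the hunk above: only run the fused
        # residual-add + next-layer norm when that norm exists on this rank.
        if self.next_layer_layernorm is not None:
            hidden_states = hidden_states + residual
            residual = hidden_states
            hidden_states = self.next_layer_layernorm(hidden_states)
        return hidden_states, residual


if __name__ == "__main__":
    layers = [ToyDecoderLayer(16) for _ in range(2)]
    # Each layer points at the next local layer's input layernorm; the last
    # local layer has none, so it skips the fused epilogue.
    layers[0].next_layer_layernorm = nn.LayerNorm(16)

    x = torch.randn(2, 4, 16)
    res = torch.zeros_like(x)
    for layer in layers:
        x, res = layer(x, res)
    print(x.shape)  # torch.Size([2, 4, 16])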
@@ -950,6 +952,7 @@ class Llama4ForConditionalGeneration(Llama4ForCausalLM):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         return_context_logits: Optional[bool] = False,
         spec_metadata: Optional[SpecMetadata] = None,
+        pipeline_interface: Optional[PipelineInterface] = None,
         **kwargs,
     ) -> torch.Tensor:
         mm_embed = kwargs.get("multi_modal_data", [])
@@ -960,7 +963,8 @@ class Llama4ForConditionalGeneration(Llama4ForCausalLM):
                                  position_ids,
                                  inputs_embeds,
                                  spec_metadata=spec_metadata,
-                                 return_context_logits=return_context_logits)
+                                 return_context_logits=return_context_logits,
+                                 pipeline_interface=pipeline_interface)
         return logits
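
The last two hunks thread the new `pipeline_interface` keyword through the multimodal wrapper's `forward` so it reaches the base `Llama4ForCausalLM.forward` instead of being dropped, which would otherwise break every pipeline stage after the first. Here is a rough sketch of that hand-off pattern, assuming a `PipelineInterface`-like object carries activations from the previous stage; the `Toy*` names are hypothetical and do not mirror the library's API.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyPipelineInterface:
    """Hypothetical stand-in: activations handed over by the previous stage."""
    hidden_states: torch.Tensor


class ToyCausalLM:

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                return_context_logits: bool = False,
                pipeline_interface: Optional[ToyPipelineInterface] = None
                ) -> torch.Tensor:
        if pipeline_interface is not None:
            # Non-first pipeline stages resume from the received activations.
            hidden = pipeline_interface.hidden_states
        else:
            # First stage: start from (toy) embeddings of the input ids.
            hidden = torch.randn(*input_ids.shape, 8)
        return hidden.sum(dim=-1)


class ToyConditionalGeneration(ToyCausalLM):
    """Multimodal wrapper; the fix amounts to not dropping the kwarg."""

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                return_context_logits: bool = False,
                pipeline_interface: Optional[ToyPipelineInterface] = None,
                **kwargs) -> torch.Tensor:
        return super().forward(input_ids,
                               return_context_logits=return_context_logits,
                               pipeline_interface=pipeline_interface)


if __name__ == "__main__":
    model = ToyConditionalGeneration()
    # First-stage path: starts from input ids.
    print(model.forward(torch.zeros(1, 4, dtype=torch.long)).shape)
    # Later-stage path: starts from activations received over the pipeline.
    stage_input = ToyPipelineInterface(torch.randn(1, 4, 8))
    print(model.forward(pipeline_interface=stage_input).shape)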