From 1770dd96d89650c81be2e2d64954cfc6f6dd57ce Mon Sep 17 00:00:00 2001
From: v-shobhit <161510941+v-shobhit@users.noreply.github.com>
Date: Mon, 12 May 2025 22:54:37 -0700
Subject: [PATCH] Fix Pipeline Parallelism in Llama4 (#4106)

Signed-off-by: Shobhit Verma
---
 tensorrt_llm/_torch/models/modeling_llama.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py
index e089128265..f882e641b8 100644
--- a/tensorrt_llm/_torch/models/modeling_llama.py
+++ b/tensorrt_llm/_torch/models/modeling_llama.py
@@ -440,7 +440,9 @@ class Llama4DecoderLayer(DecoderLayer):
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)
 
-        if self.fusion_config.POST_MOE_FUSION or self.fusion_config.POST_MLP_FUSION:
+        if (self.fusion_config.POST_MOE_FUSION
+                or self.fusion_config.POST_MLP_FUSION
+            ) and self.next_layer_layernorm is not None:
             if cutlass_min_latency_mode:
                 shared_output = hidden_states[0]
                 hidden_states_activated_experts = hidden_states[1]
@@ -950,6 +952,7 @@ class Llama4ForConditionalGeneration(Llama4ForCausalLM):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         return_context_logits: Optional[bool] = False,
         spec_metadata: Optional[SpecMetadata] = None,
+        pipeline_interface: Optional[PipelineInterface] = None,
         **kwargs,
     ) -> torch.Tensor:
         mm_embed = kwargs.get("multi_modal_data", [])
@@ -960,7 +963,8 @@ class Llama4ForConditionalGeneration(Llama4ForCausalLM):
             position_ids,
             inputs_embeds,
             spec_metadata=spec_metadata,
-            return_context_logits=return_context_logits)
+            return_context_logits=return_context_logits,
+            pipeline_interface=pipeline_interface)
         return logits
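
Note on the patch: the first hunk skips the post-MoE/post-MLP fusion path whenever self.next_layer_layernorm is None, and the second and third hunks simply thread the optional pipeline_interface argument through Llama4ForConditionalGeneration.forward to the underlying language model's forward. The sketch below is not the TensorRT-LLM code; all names in it (DecoderLayerSketch, fused_allreduce_norm, post_fusion_enabled) are made up for illustration. It shows, under the assumption that next_layer_layernorm is only assigned when the following decoder layer lives on the same pipeline rank, why the fused path must be guarded: the last layer owned by a non-final pipeline rank has no next-layer norm to fuse with.

from typing import Optional, Tuple

def fused_allreduce_norm(hidden: float, residual: float, norm) -> Tuple[float, float]:
    # Stand-in for a fused allreduce + residual-add + RMSNorm kernel.
    summed = hidden + residual
    return norm(summed), summed

class DecoderLayerSketch:
    def __init__(self, post_fusion_enabled: bool):
        self.post_fusion_enabled = post_fusion_enabled
        # Assumed to be filled in after the whole pipeline stage is built;
        # stays None for the stage's last layer when the "next" layer lives
        # on another pipeline rank.
        self.next_layer_layernorm: Optional[object] = None

    def finish(self, hidden: float, residual: float) -> Tuple[float, float]:
        # The guard mirrors the patched condition: fuse only when both the
        # fusion flag is set and a next-layer norm exists on this rank.
        if self.post_fusion_enabled and self.next_layer_layernorm is not None:
            return fused_allreduce_norm(hidden, residual,
                                        self.next_layer_layernorm)
        # Unfused fallback: pass the raw hidden states and residual onward;
        # the next pipeline stage (or the final norm) normalizes them itself.
        return hidden, residual

# Usage: the last layer on an intermediate pipeline rank takes the fallback.
layer = DecoderLayerSketch(post_fusion_enabled=True)
print(layer.finish(1.0, 2.0))  # -> (1.0, 2.0): fusion skipped, no crash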