diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py
index e089128265..f882e641b8 100644
--- a/tensorrt_llm/_torch/models/modeling_llama.py
+++ b/tensorrt_llm/_torch/models/modeling_llama.py
@@ -440,7 +440,9 @@ class Llama4DecoderLayer(DecoderLayer):
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states,
                                                       residual)
-        if self.fusion_config.POST_MOE_FUSION or self.fusion_config.POST_MLP_FUSION:
+        if (self.fusion_config.POST_MOE_FUSION
+                or self.fusion_config.POST_MLP_FUSION
+            ) and self.next_layer_layernorm is not None:
             if cutlass_min_latency_mode:
                 shared_output = hidden_states[0]
                 hidden_states_activated_experts = hidden_states[1]
@@ -950,6 +952,7 @@ class Llama4ForConditionalGeneration(Llama4ForCausalLM):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         return_context_logits: Optional[bool] = False,
         spec_metadata: Optional[SpecMetadata] = None,
+        pipeline_interface: Optional[PipelineInterface] = None,
         **kwargs,
     ) -> torch.Tensor:
         mm_embed = kwargs.get("multi_modal_data", [])
@@ -960,7 +963,8 @@ class Llama4ForConditionalGeneration(Llama4ForCausalLM):
                                  position_ids,
                                  inputs_embeds,
                                  spec_metadata=spec_metadata,
-                                 return_context_logits=return_context_logits)
+                                 return_context_logits=return_context_logits,
+                                 pipeline_interface=pipeline_interface)
         return logits
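
For context, a minimal sketch of why the new `next_layer_layernorm is not None` guard is needed: with the pipeline support this patch threads through (`pipeline_interface`), the last decoder layer on a pipeline stage presumably has no following layer on the same rank, so its `next_layer_layernorm` is never set and the post-MoE/MLP fusion path (which applies the next layer's norm) must be skipped. Everything below (`ToyDecoderLayer`, `post_moe_fusion`, the `Linear` stand-in for the MoE/MLP block) is hypothetical illustration, not the actual `Llama4DecoderLayer` code:

```python
from typing import Optional, Tuple

import torch


class ToyDecoderLayer(torch.nn.Module):
    """Toy layer; all names here are illustrative, not the real TRT-LLM API."""

    def __init__(self, hidden_size: int, post_moe_fusion: bool):
        super().__init__()
        self.mlp = torch.nn.Linear(hidden_size, hidden_size)
        self.post_moe_fusion = post_moe_fusion
        # Wired up after all layers are built; left as None for the last
        # layer of a pipeline stage, which has no next-layer norm to fuse into.
        self.next_layer_layernorm: Optional[torch.nn.Module] = None

    def forward(
        self, hidden_states: torch.Tensor, residual: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        hidden_states = self.mlp(hidden_states)
        # Same shape of guard as the diff: take the fused residual-add + norm
        # path only when fusion is on AND a next-layer norm exists on this rank.
        if self.post_moe_fusion and self.next_layer_layernorm is not None:
            residual = hidden_states + residual
            hidden_states = self.next_layer_layernorm(residual)
        return hidden_states, residual


# The last layer on a stage keeps next_layer_layernorm == None, so it skips
# the fused path and returns (hidden_states, residual) for the next rank.
last_layer = ToyDecoderLayer(hidden_size=16, post_moe_fusion=True)
h, r = last_layer(torch.randn(2, 16), torch.randn(2, 16))
```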