Fix Pipeline Parallelism in Llama4 (#4106)
Signed-off-by: Shobhit Verma <shobhitv@nvidia.com>
commit 1770dd96d8
parent 13c8e5a8a8
@@ -440,7 +440,9 @@ class Llama4DecoderLayer(DecoderLayer):
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)
 
-        if self.fusion_config.POST_MOE_FUSION or self.fusion_config.POST_MLP_FUSION:
+        if (self.fusion_config.POST_MOE_FUSION
+                or self.fusion_config.POST_MLP_FUSION
+            ) and self.next_layer_layernorm is not None:
             if cutlass_min_latency_mode:
                 shared_output = hidden_states[0]
                 hidden_states_activated_experts = hidden_states[1]
@@ -950,6 +952,7 @@ class Llama4ForConditionalGeneration(Llama4ForCausalLM):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         return_context_logits: Optional[bool] = False,
         spec_metadata: Optional[SpecMetadata] = None,
+        pipeline_interface: Optional[PipelineInterface] = None,
         **kwargs,
     ) -> torch.Tensor:
         mm_embed = kwargs.get("multi_modal_data", [])
@@ -960,7 +963,8 @@ class Llama4ForConditionalGeneration(Llama4ForCausalLM):
             position_ids,
             inputs_embeds,
             spec_metadata=spec_metadata,
-            return_context_logits=return_context_logits)
+            return_context_logits=return_context_logits,
+            pipeline_interface=pipeline_interface)
         return logits
 
 
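A plausible reading of the two changes: the POST_MOE/POST_MLP fusion path folds the post-layer allreduce into the next layer's input layernorm, but under pipeline parallelism the last decoder layer on a rank has no local successor, so `next_layer_layernorm` stays `None` and the fused path must be skipped; the new `pipeline_interface` argument then threads the stage handoff through `Llama4ForConditionalGeneration.forward`. Below is a minimal, self-contained sketch of the guard's effect; `FusionConfig`, `DecoderLayerSketch`, and the lambda layernorm stand-in are illustrative assumptions, not TensorRT-LLM's actual classes.

from dataclasses import dataclass


@dataclass
class FusionConfig:
    POST_MOE_FUSION: bool = False
    POST_MLP_FUSION: bool = False


class DecoderLayerSketch:

    def __init__(self, layer_idx: int, fusion_config: FusionConfig):
        self.layer_idx = layer_idx
        self.fusion_config = fusion_config
        # Under pipeline parallelism, the last layer on a PP rank has no
        # local successor, so this stays None and fusion must be skipped.
        self.next_layer_layernorm = None

    def forward(self, hidden_states, residual):
        if (self.fusion_config.POST_MOE_FUSION
                or self.fusion_config.POST_MLP_FUSION
            ) and self.next_layer_layernorm is not None:
            # Fused path: apply the next layer's input layernorm here
            # (stands in for the fused allreduce + norm kernel).
            return self.next_layer_layernorm(hidden_states + residual)
        # Unfused fallback: hand back raw outputs; the next PP stage (or
        # the model head) applies its own normalization.
        return hidden_states, residual


# Wire layers the way a PP rank might see them: only a layer with a local
# successor gets a next_layer_layernorm.
layers = [DecoderLayerSketch(i, FusionConfig(POST_MOE_FUSION=True))
          for i in range(2)]
layers[0].next_layer_layernorm = lambda x: x  # stand-in for RMSNorm
print(layers[0].forward(1.0, 2.0))  # fused path -> 3.0
print(layers[1].forward(1.0, 2.0))  # last layer on the rank -> (1.0, 2.0)

On this reading, the design degrades gracefully: fusion remains opt-in per layer and falls back to the unfused path exactly at pipeline-stage boundaries, which is what the added `is not None` guard encodes.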