[Bugfix] Fix Step3 pipeline parallel KeyError for residual tensor (#37622)

Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Joaquín Mondéjar
2026-05-29 12:04:02 +02:00
committed by GitHub
parent 7ebc0ec104
commit 60a7a2214f
+1 -1
View File
@@ -345,7 +345,7 @@ class Step3TextModel(nn.Module):
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states"], config.hidden_size
["hidden_states", "residual"], config.hidden_size
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: