diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index d995025952..ed65bc5451 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -992,7 +992,9 @@ class LTXVideoDecoder3d(nn.Module): # timestep embedding self.time_embedder = None self.scale_shift_table = None + self.timestep_scale_multiplier = None if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32)) self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) @@ -1001,6 +1003,9 @@ class LTXVideoDecoder3d(nn.Module): def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: hidden_states = self.conv_in(hidden_states) + if self.timestep_scale_multiplier is not None: + temb = temb * self.timestep_scale_multiplier + if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)