vae fix
This commit is contained in:
@@ -992,7 +992,9 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
# timestep embedding
|
||||
self.time_embedder = None
|
||||
self.scale_shift_table = None
|
||||
self.timestep_scale_multiplier = None
|
||||
if timestep_conditioning:
|
||||
self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
|
||||
|
||||
@@ -1001,6 +1003,9 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if self.timestep_scale_multiplier is not None:
|
||||
temb = temb * self.timestep_scale_multiplier
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user