Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5f6a41ed02 | |||
| 91fbe16b63 |
@@ -117,6 +117,8 @@ class CogVideoXCausalConv3d(nn.Module):
|
|||||||
dilation=dilation,
|
dilation=dilation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.return_conv_cache = True
|
||||||
|
|
||||||
def fake_context_parallel_forward(
|
def fake_context_parallel_forward(
|
||||||
self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
|
self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@@ -128,7 +130,10 @@ class CogVideoXCausalConv3d(nn.Module):
|
|||||||
|
|
||||||
def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
|
def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
inputs = self.fake_context_parallel_forward(inputs, conv_cache)
|
inputs = self.fake_context_parallel_forward(inputs, conv_cache)
|
||||||
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
|
if self.return_conv_cache:
|
||||||
|
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
|
||||||
|
else:
|
||||||
|
conv_cache = None
|
||||||
|
|
||||||
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||||
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
|
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
|
||||||
@@ -1079,6 +1084,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|||||||
|
|
||||||
self.use_slicing = False
|
self.use_slicing = False
|
||||||
self.use_tiling = False
|
self.use_tiling = False
|
||||||
|
self.use_framewise_batching = True
|
||||||
|
|
||||||
# Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
|
# Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
|
||||||
# recommended because the temporal parts of the VAE, here, are tricky to understand.
|
# recommended because the temporal parts of the VAE, here, are tricky to understand.
|
||||||
@@ -1174,6 +1180,20 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|||||||
"""
|
"""
|
||||||
self.use_slicing = False
|
self.use_slicing = False
|
||||||
|
|
||||||
|
def enable_framewise_batching(self) -> None:
|
||||||
|
self.use_framewise_batching = True
|
||||||
|
|
||||||
|
for name, module in self.named_modules():
|
||||||
|
if isinstance(module, CogVideoXCausalConv3d):
|
||||||
|
module.return_conv_cache = True
|
||||||
|
|
||||||
|
def disable_framewise_batching(self) -> None:
|
||||||
|
self.use_framewise_batching = False
|
||||||
|
|
||||||
|
for name, module in self.named_modules():
|
||||||
|
if isinstance(module, CogVideoXCausalConv3d):
|
||||||
|
module.return_conv_cache = False
|
||||||
|
|
||||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
batch_size, num_channels, num_frames, height, width = x.shape
|
batch_size, num_channels, num_frames, height, width = x.shape
|
||||||
|
|
||||||
@@ -1184,19 +1204,26 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|||||||
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
|
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
|
||||||
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
|
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
|
||||||
conv_cache = None
|
conv_cache = None
|
||||||
enc = []
|
|
||||||
|
|
||||||
for i in range(num_batches):
|
if self.use_framewise_batching:
|
||||||
remaining_frames = num_frames % frame_batch_size
|
enc = []
|
||||||
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
|
||||||
end_frame = frame_batch_size * (i + 1) + remaining_frames
|
for i in range(num_batches):
|
||||||
x_intermediate = x[:, :, start_frame:end_frame]
|
remaining_frames = num_frames % frame_batch_size
|
||||||
x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
|
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
||||||
|
end_frame = frame_batch_size * (i + 1) + remaining_frames
|
||||||
|
x_intermediate = x[:, :, start_frame:end_frame]
|
||||||
|
x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
|
||||||
|
if self.quant_conv is not None:
|
||||||
|
x_intermediate = self.quant_conv(x_intermediate)
|
||||||
|
enc.append(x_intermediate)
|
||||||
|
|
||||||
|
enc = torch.cat(enc, dim=2)
|
||||||
|
else:
|
||||||
|
enc, _ = self.encoder(x, conv_cache=conv_cache)
|
||||||
if self.quant_conv is not None:
|
if self.quant_conv is not None:
|
||||||
x_intermediate = self.quant_conv(x_intermediate)
|
enc = self.quant_conv(enc)
|
||||||
enc.append(x_intermediate)
|
|
||||||
|
|
||||||
enc = torch.cat(enc, dim=2)
|
|
||||||
return enc
|
return enc
|
||||||
|
|
||||||
@apply_forward_hook
|
@apply_forward_hook
|
||||||
@@ -1236,19 +1263,25 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|||||||
frame_batch_size = self.num_latent_frames_batch_size
|
frame_batch_size = self.num_latent_frames_batch_size
|
||||||
num_batches = max(num_frames // frame_batch_size, 1)
|
num_batches = max(num_frames // frame_batch_size, 1)
|
||||||
conv_cache = None
|
conv_cache = None
|
||||||
dec = []
|
|
||||||
|
|
||||||
for i in range(num_batches):
|
if self.use_framewise_batching:
|
||||||
remaining_frames = num_frames % frame_batch_size
|
dec = []
|
||||||
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
|
||||||
end_frame = frame_batch_size * (i + 1) + remaining_frames
|
for i in range(num_batches):
|
||||||
z_intermediate = z[:, :, start_frame:end_frame]
|
remaining_frames = num_frames % frame_batch_size
|
||||||
|
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
||||||
|
end_frame = frame_batch_size * (i + 1) + remaining_frames
|
||||||
|
z_intermediate = z[:, :, start_frame:end_frame]
|
||||||
|
if self.post_quant_conv is not None:
|
||||||
|
z_intermediate = self.post_quant_conv(z_intermediate)
|
||||||
|
z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
|
||||||
|
dec.append(z_intermediate)
|
||||||
|
|
||||||
|
dec = torch.cat(dec, dim=2)
|
||||||
|
else:
|
||||||
if self.post_quant_conv is not None:
|
if self.post_quant_conv is not None:
|
||||||
z_intermediate = self.post_quant_conv(z_intermediate)
|
dec = self.post_quant_conv(z)
|
||||||
z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
|
dec, _ = self.decoder(z, conv_cache=conv_cache)
|
||||||
dec.append(z_intermediate)
|
|
||||||
|
|
||||||
dec = torch.cat(dec, dim=2)
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (dec,)
|
return (dec,)
|
||||||
|
|||||||
Reference in New Issue
Block a user