[core] LTX Video 0.9.1 (#10330)
* update * make style * update * update * update * make style * single file related changes * update * fix * update single file urls and docs * update * fix
This commit is contained in:
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# LTX
|
||||
# LTX Video
|
||||
|
||||
[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
|
||||
|
||||
@@ -22,14 +22,24 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
|
||||
|
||||
</Tip>
|
||||
|
||||
Available models:
|
||||
|
||||
| Model name | Recommended dtype |
|
||||
|:-------------:|:-----------------:|
|
||||
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
|
||||
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
|
||||
|
||||
Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
|
||||
|
||||
## Loading Single Files
|
||||
|
||||
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`].
|
||||
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
|
||||
|
||||
# `single_file_url` could also be https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.1.safetensors
|
||||
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
|
||||
transformer = LTXVideoTransformer3DModel.from_single_file(
|
||||
single_file_url, torch_dtype=torch.bfloat16
|
||||
@@ -99,6 +109,34 @@ export_to_video(video, "output_gguf_ltx.mp4", fps=24)
|
||||
|
||||
Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
|
||||
|
||||
<!-- TODO(aryan): Update this when official weights are supported -->
|
||||
|
||||
Loading and running inference with [LTX Video 0.9.1](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) weights.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import LTXPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
|
||||
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=768,
|
||||
height=512,
|
||||
num_frames=161,
|
||||
decode_timestep=0.03,
|
||||
decode_noise_scale=0.025,
|
||||
num_inference_steps=50,
|
||||
).frames[0]
|
||||
export_to_video(video, "output.mp4", fps=24)
|
||||
```
|
||||
|
||||
Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
|
||||
|
||||
## LTXPipeline
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from safetensors.torch import load_file
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
@@ -21,7 +23,9 @@ TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"vae": remove_keys_,
|
||||
}
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {
|
||||
# decoder
|
||||
@@ -54,10 +58,31 @@ VAE_KEYS_RENAME_DICT = {
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
VAE_091_RENAME_DICT = {
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
# common
|
||||
"last_time_embedder": "time_embedder",
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_,
|
||||
"model.diffusion_model": remove_keys_,
|
||||
}
|
||||
|
||||
VAE_091_SPECIAL_KEYS_REMAP = {
|
||||
"timestep_scale_multiplier": remove_keys_,
|
||||
}
|
||||
|
||||
|
||||
@@ -80,13 +105,16 @@ def convert_transformer(
|
||||
ckpt_path: str,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
PREFIX_KEY = ""
|
||||
PREFIX_KEY = "model.diffusion_model."
|
||||
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
transformer = LTXVideoTransformer3DModel().to(dtype=dtype)
|
||||
with init_empty_weights():
|
||||
transformer = LTXVideoTransformer3DModel()
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
new_key = key[:]
|
||||
if new_key.startswith(PREFIX_KEY):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
@@ -97,16 +125,21 @@ def convert_transformer(
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True)
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_vae(ckpt_path: str, dtype: torch.dtype):
|
||||
def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
|
||||
PREFIX_KEY = "vae."
|
||||
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
vae = AutoencoderKLLTXVideo().to(dtype=dtype)
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLLTXVideo(**config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
if new_key.startswith(PREFIX_KEY):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
@@ -117,10 +150,60 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype):
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True)
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_vae_config(version: str) -> Dict[str, Any]:
|
||||
if version == "0.9.0":
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"decoder_block_out_channels": (128, 256, 512, 512),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (4, 3, 3, 3, 4),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (1, 1, 1, 1),
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"scaling_factor": 1.0,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"timestep_conditioning": False,
|
||||
}
|
||||
elif version == "0.9.1":
|
||||
config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (128, 256, 512, 512),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 3, 3, 3, 4),
|
||||
"decoder_layers_per_block": (5, 6, 7, 8),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (True, True, True, False),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"scaling_factor": 1.0,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
}
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
|
||||
return config
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@@ -139,6 +222,9 @@ def get_args():
|
||||
parser.add_argument("--save_pipeline", action="store_true")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
|
||||
parser.add_argument(
|
||||
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -161,6 +247,7 @@ if __name__ == "__main__":
|
||||
transformer = None
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
variant = VARIANT_MAPPING[args.dtype]
|
||||
output_path = Path(args.output_path)
|
||||
|
||||
if args.save_pipeline:
|
||||
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
|
||||
@@ -169,13 +256,14 @@ if __name__ == "__main__":
|
||||
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(
|
||||
args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
)
|
||||
|
||||
if args.vae_ckpt_path is not None:
|
||||
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype)
|
||||
config = get_vae_config(args.version)
|
||||
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, config, dtype)
|
||||
if not args.save_pipeline:
|
||||
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
|
||||
if args.save_pipeline:
|
||||
text_encoder_id = "google/t5-v1_1-xxl"
|
||||
|
||||
@@ -157,7 +157,8 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
|
||||
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
|
||||
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
|
||||
"ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
|
||||
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
|
||||
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
|
||||
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
|
||||
@@ -605,7 +606,10 @@ def infer_diffusers_model_type(checkpoint):
|
||||
model_type = "flux-schnell"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
|
||||
model_type = "ltx-video"
|
||||
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
|
||||
model_type = "ltx-video-0.9.1"
|
||||
else:
|
||||
model_type = "ltx-video"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
|
||||
encoder_key = "encoder.project_in.conv.conv.bias"
|
||||
@@ -2338,12 +2342,32 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
VAE_091_RENAME_DICT = {
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
# common
|
||||
"last_time_embedder": "time_embedder",
|
||||
"last_scale_shift_table": "scale_shift_table",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_,
|
||||
"timestep_scale_multiplier": remove_keys_,
|
||||
}
|
||||
|
||||
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
|
||||
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
new_key = key
|
||||
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
||||
|
||||
@@ -22,6 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm
|
||||
@@ -109,7 +110,9 @@ class LTXVideoResnetBlock3d(nn.Module):
|
||||
elementwise_affine: bool = False,
|
||||
non_linearity: str = "swish",
|
||||
is_causal: bool = True,
|
||||
):
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
@@ -135,18 +138,54 @@ class LTXVideoResnetBlock3d(nn.Module):
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
self.per_channel_scale1 = None
|
||||
self.per_channel_scale2 = None
|
||||
if inject_noise:
|
||||
self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1))
|
||||
self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1))
|
||||
|
||||
self.scale_shift_table = None
|
||||
if timestep_conditioning:
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5)
|
||||
|
||||
def forward(
|
||||
self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None
|
||||
) -> torch.Tensor:
|
||||
hidden_states = inputs
|
||||
|
||||
hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.scale_shift_table is not None:
|
||||
temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None]
|
||||
shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1)
|
||||
hidden_states = hidden_states * (1 + scale_1) + shift_1
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
if self.per_channel_scale1 is not None:
|
||||
spatial_shape = hidden_states.shape[-2:]
|
||||
spatial_noise = torch.randn(
|
||||
spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)[None]
|
||||
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...]
|
||||
|
||||
hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.scale_shift_table is not None:
|
||||
hidden_states = hidden_states * (1 + scale_2) + shift_2
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.per_channel_scale2 is not None:
|
||||
spatial_shape = hidden_states.shape[-2:]
|
||||
spatial_noise = torch.randn(
|
||||
spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)[None]
|
||||
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...]
|
||||
|
||||
if self.norm3 is not None:
|
||||
inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
@@ -163,12 +202,16 @@ class LTXVideoUpsampler3d(nn.Module):
|
||||
in_channels: int,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
is_causal: bool = True,
|
||||
residual: bool = False,
|
||||
upscale_factor: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
|
||||
self.residual = residual
|
||||
self.upscale_factor = upscale_factor
|
||||
|
||||
out_channels = in_channels * stride[0] * stride[1] * stride[2]
|
||||
out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
|
||||
|
||||
self.conv = LTXVideoCausalConv3d(
|
||||
in_channels=in_channels,
|
||||
@@ -181,6 +224,15 @@ class LTXVideoUpsampler3d(nn.Module):
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
|
||||
if self.residual:
|
||||
residual = hidden_states.reshape(
|
||||
batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
|
||||
)
|
||||
residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor
|
||||
residual = residual.repeat(1, repeats, 1, 1, 1)
|
||||
residual = residual[:, :, self.stride[0] - 1 :]
|
||||
|
||||
hidden_states = self.conv(hidden_states)
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
|
||||
@@ -188,6 +240,9 @@ class LTXVideoUpsampler3d(nn.Module):
|
||||
hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
hidden_states = hidden_states[:, :, self.stride[0] - 1 :]
|
||||
|
||||
if self.residual:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -273,7 +328,12 @@ class LTXVideoDownBlock3D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `LTXDownBlock3D` class."""
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
@@ -285,16 +345,18 @@ class LTXVideoDownBlock3D(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, generator
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = resnet(hidden_states, temb, generator)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
if self.conv_out is not None:
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states, temb, generator)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -329,9 +391,15 @@ class LTXVideoMidBlock3d(nn.Module):
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_act_fn: str = "swish",
|
||||
is_causal: bool = True,
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.time_embedder = None
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
|
||||
|
||||
resnets = []
|
||||
for _ in range(num_layers):
|
||||
resnets.append(
|
||||
@@ -342,15 +410,32 @@ class LTXVideoMidBlock3d(nn.Module):
|
||||
eps=resnet_eps,
|
||||
non_linearity=resnet_act_fn,
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `LTXMidBlock3D` class."""
|
||||
|
||||
if self.time_embedder is not None:
|
||||
temb = self.time_embedder(
|
||||
timestep=temb.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=hidden_states.size(0),
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
@@ -360,9 +445,11 @@ class LTXVideoMidBlock3d(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, generator
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = resnet(hidden_states, temb, generator)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -403,11 +490,19 @@ class LTXVideoUpBlock3d(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
spatio_temporal_scale: bool = True,
|
||||
is_causal: bool = True,
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
upsample_residual: bool = False,
|
||||
upscale_factor: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
self.time_embedder = None
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
|
||||
|
||||
self.conv_in = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_in = LTXVideoResnetBlock3d(
|
||||
@@ -417,11 +512,23 @@ class LTXVideoUpBlock3d(nn.Module):
|
||||
eps=resnet_eps,
|
||||
non_linearity=resnet_act_fn,
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
|
||||
self.upsamplers = None
|
||||
if spatio_temporal_scale:
|
||||
self.upsamplers = nn.ModuleList([LTXVideoUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])
|
||||
self.upsamplers = nn.ModuleList(
|
||||
[
|
||||
LTXVideoUpsampler3d(
|
||||
out_channels * upscale_factor,
|
||||
stride=(2, 2, 2),
|
||||
is_causal=is_causal,
|
||||
residual=upsample_residual,
|
||||
upscale_factor=upscale_factor,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
resnets = []
|
||||
for _ in range(num_layers):
|
||||
@@ -433,15 +540,32 @@ class LTXVideoUpBlock3d(nn.Module):
|
||||
eps=resnet_eps,
|
||||
non_linearity=resnet_act_fn,
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> torch.Tensor:
|
||||
if self.conv_in is not None:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
hidden_states = self.conv_in(hidden_states, temb, generator)
|
||||
|
||||
if self.time_embedder is not None:
|
||||
temb = self.time_embedder(
|
||||
timestep=temb.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=hidden_states.size(0),
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -456,9 +580,11 @@ class LTXVideoUpBlock3d(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, generator
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = resnet(hidden_states, temb, generator)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -623,6 +749,8 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
Epsilon value for ResNet normalization layers.
|
||||
is_causal (`bool`, defaults to `False`):
|
||||
Whether this layer behaves causally (future frames depend only on past frames) or not.
|
||||
timestep_conditioning (`bool`, defaults to `False`):
|
||||
Whether to condition the model on timesteps.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -636,6 +764,10 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
is_causal: bool = False,
|
||||
inject_noise: Tuple[bool, ...] = (False, False, False, False),
|
||||
timestep_conditioning: bool = False,
|
||||
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1),
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -646,6 +778,9 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
block_out_channels = tuple(reversed(block_out_channels))
|
||||
spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
|
||||
layers_per_block = tuple(reversed(layers_per_block))
|
||||
inject_noise = tuple(reversed(inject_noise))
|
||||
upsample_residual = tuple(reversed(upsample_residual))
|
||||
upsample_factor = tuple(reversed(upsample_factor))
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
self.conv_in = LTXVideoCausalConv3d(
|
||||
@@ -653,15 +788,20 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
)
|
||||
|
||||
self.mid_block = LTXVideoMidBlock3d(
|
||||
in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal
|
||||
in_channels=output_channel,
|
||||
num_layers=layers_per_block[0],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise[0],
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
|
||||
# up blocks
|
||||
num_block_out_channels = len(block_out_channels)
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
for i in range(num_block_out_channels):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
input_channel = output_channel // upsample_factor[i]
|
||||
output_channel = block_out_channels[i] // upsample_factor[i]
|
||||
|
||||
up_block = LTXVideoUpBlock3d(
|
||||
in_channels=input_channel,
|
||||
@@ -670,6 +810,10 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
resnet_eps=resnet_norm_eps,
|
||||
spatio_temporal_scale=spatio_temporal_scaling[i],
|
||||
is_causal=is_causal,
|
||||
inject_noise=inject_noise[i + 1],
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
upsample_residual=upsample_residual[i],
|
||||
upscale_factor=upsample_factor[i],
|
||||
)
|
||||
|
||||
self.up_blocks.append(up_block)
|
||||
@@ -681,9 +825,16 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
# timestep embedding
|
||||
self.time_embedder = None
|
||||
self.scale_shift_table = None
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
@@ -694,17 +845,33 @@ class LTXVideoDecoder3d(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), hidden_states, temb
|
||||
)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb)
|
||||
else:
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
hidden_states = self.mid_block(hidden_states, temb)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = up_block(hidden_states)
|
||||
hidden_states = up_block(hidden_states, temb)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.time_embedder is not None:
|
||||
temb = self.time_embedder(
|
||||
timestep=temb.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=hidden_states.size(0),
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1))
|
||||
temb = temb + self.scale_shift_table[None, ..., None, None, None]
|
||||
shift, scale = temb.unbind(dim=1)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
@@ -767,8 +934,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 128,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
|
||||
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
|
||||
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
|
||||
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
|
||||
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
|
||||
timestep_conditioning: bool = False,
|
||||
patch_size: int = 4,
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
@@ -792,13 +966,17 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.decoder = LTXVideoDecoder3d(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
block_out_channels=block_out_channels,
|
||||
spatio_temporal_scaling=spatio_temporal_scaling,
|
||||
layers_per_block=layers_per_block,
|
||||
block_out_channels=decoder_block_out_channels,
|
||||
spatio_temporal_scaling=decoder_spatio_temporal_scaling,
|
||||
layers_per_block=decoder_layers_per_block,
|
||||
patch_size=patch_size,
|
||||
patch_size_t=patch_size_t,
|
||||
resnet_norm_eps=resnet_norm_eps,
|
||||
is_causal=decoder_causal,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
inject_noise=decoder_inject_noise,
|
||||
upsample_residual=upsample_residual,
|
||||
upsample_factor=upsample_factor,
|
||||
)
|
||||
|
||||
latents_mean = torch.zeros((latent_channels,), requires_grad=False)
|
||||
@@ -937,13 +1115,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def _decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
||||
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
||||
|
||||
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
return self.tiled_decode(z, temb, return_dict=return_dict)
|
||||
|
||||
if self.use_framewise_decoding:
|
||||
# TODO(aryan): requires investigation
|
||||
@@ -953,7 +1133,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
|
||||
)
|
||||
else:
|
||||
dec = self.decoder(z)
|
||||
dec = self.decoder(z, temb)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
@@ -961,7 +1141,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -976,10 +1158,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
returned.
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
if temb is not None:
|
||||
decoded_slices = [
|
||||
self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1))
|
||||
]
|
||||
else:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z).sample
|
||||
decoded = self._decode(z, temb).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
@@ -1061,7 +1248,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
|
||||
return enc
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def tiled_decode(
|
||||
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
@@ -1102,7 +1291,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
|
||||
)
|
||||
else:
|
||||
time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width])
|
||||
time = self.decoder(
|
||||
z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
|
||||
)
|
||||
|
||||
row.append(time)
|
||||
rows.append(row)
|
||||
@@ -1130,6 +1321,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
@@ -1140,7 +1332,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
dec = self.decode(z, temb)
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
return dec
|
||||
|
||||
@@ -511,6 +511,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -563,6 +565,10 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -753,7 +759,25 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
|
||||
@@ -571,6 +571,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -625,6 +627,10 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -849,7 +855,25 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
|
||||
@@ -52,10 +52,19 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
}
|
||||
transformer_cls = LTXVideoTransformer3DModel
|
||||
vae_kwargs = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 8,
|
||||
"block_out_channels": (8, 8, 8, 8),
|
||||
"spatio_temporal_scaling": (True, True, False, False),
|
||||
"decoder_block_out_channels": (8, 8, 8, 8),
|
||||
"layers_per_block": (1, 1, 1, 1, 1),
|
||||
"decoder_layers_per_block": (1, 1, 1, 1, 1),
|
||||
"spatio_temporal_scaling": (True, True, False, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, False, False),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (1, 1, 1, 1),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"encoder_causal": True,
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderKLLTXVideo
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLLTXVideo
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_ltx_video_config(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 8,
|
||||
"block_out_channels": (8, 8, 8, 8),
|
||||
"decoder_block_out_channels": (8, 8, 8, 8),
|
||||
"layers_per_block": (1, 1, 1, 1, 1),
|
||||
"decoder_layers_per_block": (1, 1, 1, 1, 1),
|
||||
"spatio_temporal_scaling": (True, True, False, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, False, False),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (1, 1, 1, 1),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_ltx_video_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"LTXVideoEncoder3d",
|
||||
"LTXVideoDecoder3d",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoMidBlock3d",
|
||||
"LTXVideoUpBlock3d",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
|
||||
class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLLTXVideo
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_ltx_video_config(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 8,
|
||||
"block_out_channels": (8, 8, 8, 8),
|
||||
"decoder_block_out_channels": (16, 32, 64),
|
||||
"layers_per_block": (1, 1, 1, 1),
|
||||
"decoder_layers_per_block": (1, 1, 1, 1),
|
||||
"spatio_temporal_scaling": (True, True, True, False),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (True, True, True, False),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": True,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
timestep = torch.tensor([0.05] * batch_size, device=torch_device)
|
||||
|
||||
return {"sample": image, "temb": timestep}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_ltx_video_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"LTXVideoEncoder3d",
|
||||
"LTXVideoDecoder3d",
|
||||
"LTXVideoDownBlock3D",
|
||||
"LTXVideoMidBlock3d",
|
||||
"LTXVideoUpBlock3d",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
@@ -63,10 +63,19 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLLTXVideo(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
latent_channels=8,
|
||||
block_out_channels=(8, 8, 8, 8),
|
||||
spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_block_out_channels=(8, 8, 8, 8),
|
||||
layers_per_block=(1, 1, 1, 1, 1),
|
||||
decoder_layers_per_block=(1, 1, 1, 1, 1),
|
||||
spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_inject_noise=(False, False, False, False, False),
|
||||
upsample_residual=(False, False, False, False),
|
||||
upsample_factor=(1, 1, 1, 1),
|
||||
timestep_conditioning=False,
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
encoder_causal=True,
|
||||
|
||||
@@ -68,10 +68,19 @@ class LTXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLLTXVideo(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
latent_channels=8,
|
||||
block_out_channels=(8, 8, 8, 8),
|
||||
spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_block_out_channels=(8, 8, 8, 8),
|
||||
layers_per_block=(1, 1, 1, 1, 1),
|
||||
decoder_layers_per_block=(1, 1, 1, 1, 1),
|
||||
spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_spatio_temporal_scaling=(True, True, False, False),
|
||||
decoder_inject_noise=(False, False, False, False, False),
|
||||
upsample_residual=(False, False, False, False),
|
||||
upsample_factor=(1, 1, 1, 1),
|
||||
timestep_conditioning=False,
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
encoder_causal=True,
|
||||
|
||||
Reference in New Issue
Block a user