Compare commits

...

10 Commits

Author SHA1 Message Date
YiYi Xu 8b9bfaea80 Release v0.30.1 2024-08-23 15:24:29 -10:00
Dhruv Nair b12c7f8390 [Single File] Support loading Comfy UI Flux checkpoints (#9243)
update
2024-08-23 15:19:50 -10:00
zR 06f36713ae Cogvideox-5B Model adapter change (#9203)
* draft of embedding

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2024-08-23 15:17:20 -10:00
Aryan 19c5d7b376 [tests] fix broken xformers tests (#9206)
* fix xformers tests

* remove unnecessary modifications to cogvideox tests

* update
2024-08-23 15:16:58 -10:00
Sayak Paul 99a64aa63c [Flux LoRA] support parsing alpha from a flux lora state dict. (#9236)
* support parsing alpha from a flux lora state dict.

* conditional import.

* fix breaking changes.

* safeguard alpha.

* fix
2024-08-23 15:11:29 -10:00
Dhruv Nair 1bb419672d [Single File] Fix configuring scheduler via legacy kwargs (#9229)
update
2024-08-23 15:11:06 -10:00
Simo Ryu a655574710 Add Learned PE selection for Auraflow (#9182)
* add pe

* Update src/diffusers/models/transformers/auraflow_transformer_2d.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/models/transformers/auraflow_transformer_2d.py

* beauty

* retrigger ci.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-23 15:10:13 -10:00
Aryan 67a80dfbd5 [refactor] CogVideoX followups + tiled decoding support (#9150)
* refactor context parallel cache; update torch compile time benchmark

* add tiling support

* make style

* remove num_frames % 8 == 0 requirement

* update default num_frames to original value

* add explanations + refactor

* update torch compile example

* update docs

* update

* clean up if-statements

* address review comments

* add test for vae tiling

* update docs

* update docs

* update docstrings

* add modeling test for cogvideox transformer

* make style
2024-08-23 15:09:38 -10:00
Dhruv Nair 1f77300d23 Update Video Loading/Export to use imageio (#9094)
* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-23 15:09:10 -10:00
sayakpaul 8a79d8ec39 Release: v0.30.0 2024-08-07 13:00:43 +05:30
59 changed files with 1364 additions and 283 deletions
+24 -23
View File
@@ -15,9 +15,7 @@
# CogVideoX # CogVideoX
<!-- TODO: update paper with ArXiv link when ready. --> [CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://arxiv.org/abs/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang.
[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) from Tsinghua University & ZhipuAI.
The abstract from the paper is: The abstract from the paper is:
@@ -31,6 +29,10 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM). This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
There are two models available that can be used with the CogVideoX pipeline:
- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b)
- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b)
## Inference ## Inference
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency. Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
@@ -43,43 +45,42 @@ from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video from diffusers.utils import export_to_video
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b").to("cuda") pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b").to("cuda")
prompt = (
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
"atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
``` ```
Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`: Then change the memory layout of the pipelines `transformer` component to `torch.channels_last`:
```python ```python
pipeline.transformer.to(memory_format=torch.channels_last) pipe.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
``` ```
Finally, compile the components and run inference: Finally, compile the components and run inference:
```python ```python
pipeline.transformer = torch.compile(pipeline.transformer) pipe.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline.vae.decode = torch.compile(pipeline.vae.decode)
# CogVideoX works very well with long and well-described prompts # CogVideoX works well with long and well-described prompts
prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance." prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
``` ```
The [benchmark](TODO: link) results on an 80GB A100 machine are: The [benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are:
``` ```
Without torch.compile(): Average inference time: TODO seconds. Without torch.compile(): Average inference time: 96.89 seconds.
With torch.compile(): Average inference time: TODO seconds. With torch.compile(): Average inference time: 76.27 seconds.
``` ```
### Memory optimization
CogVideoX-2b requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
- `pipe.enable_model_cpu_offload()`:
- Without enabling cpu offloading, memory usage is `33 GB`
- With enabling cpu offloading, memory usage is `19 GB`
- `pipe.vae.enable_tiling()`:
- With enabling cpu offloading and tiling, memory usage is `11 GB`
- `pipe.vae.enable_slicing()`
## CogVideoXPipeline ## CogVideoXPipeline
[[autodoc]] CogVideoXPipeline [[autodoc]] CogVideoXPipeline
@@ -71,7 +71,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -43,8 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
class MarigoldDepthOutput(BaseOutput): class MarigoldDepthOutput(BaseOutput):
""" """
@@ -73,7 +73,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -66,7 +66,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
+1 -1
View File
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
if is_torch_npu_available(): if is_torch_npu_available():
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
+1 -1
View File
@@ -63,7 +63,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
+1 -1
View File
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
# Cache compiled models across invocations of this script. # Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache")) cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
+1 -1
View File
@@ -70,7 +70,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
+1 -1
View File
@@ -64,7 +64,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -57,7 +57,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -57,7 +57,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -68,7 +68,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
if is_torch_npu_available(): if is_torch_npu_available():
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
if is_torch_npu_available(): if is_torch_npu_available():
@@ -81,7 +81,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -76,7 +76,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
+1 -1
View File
@@ -50,7 +50,7 @@ if is_wandb_available():
import wandb import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -50,7 +50,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0") check_min_version("0.30.0")
logger = get_logger(__name__, log_level="INFO") logger = get_logger(__name__, log_level="INFO")
+55 -9
View File
@@ -86,6 +86,9 @@ TRANSFORMER_SPECIAL_KEYS_REMAP = {
"key_layernorm_list": reassign_query_key_layernorm_inplace, "key_layernorm_list": reassign_query_key_layernorm_inplace,
"adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace, "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
"embed_tokens": remove_keys_inplace, "embed_tokens": remove_keys_inplace,
"freqs_sin": remove_keys_inplace,
"freqs_cos": remove_keys_inplace,
"position_embedding": remove_keys_inplace,
} }
VAE_KEYS_RENAME_DICT = { VAE_KEYS_RENAME_DICT = {
@@ -123,11 +126,21 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
state_dict[new_key] = state_dict.pop(old_key) state_dict[new_key] = state_dict.pop(old_key)
def convert_transformer(ckpt_path: str): def convert_transformer(
ckpt_path: str,
num_layers: int,
num_attention_heads: int,
use_rotary_positional_embeddings: bool,
dtype: torch.dtype,
):
PREFIX_KEY = "model.diffusion_model." PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
transformer = CogVideoXTransformer3DModel() transformer = CogVideoXTransformer3DModel(
num_layers=num_layers,
num_attention_heads=num_attention_heads,
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
).to(dtype=dtype)
for key in list(original_state_dict.keys()): for key in list(original_state_dict.keys()):
new_key = key[len(PREFIX_KEY) :] new_key = key[len(PREFIX_KEY) :]
@@ -145,9 +158,9 @@ def convert_transformer(ckpt_path: str):
return transformer return transformer
def convert_vae(ckpt_path: str): def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
vae = AutoencoderKLCogVideoX() vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
for key in list(original_state_dict.keys()): for key in list(original_state_dict.keys()):
new_key = key[:] new_key = key[:]
@@ -172,13 +185,26 @@ def get_args():
) )
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint") parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16") parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
parser.add_argument( parser.add_argument(
"--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving" "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
) )
parser.add_argument( parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory" "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
) )
# For CogVideoX-2B, num_layers is 30. For 5B, it is 42
parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
# For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
# For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
parser.add_argument(
"--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
)
# For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
return parser.parse_args() return parser.parse_args()
@@ -188,18 +214,33 @@ if __name__ == "__main__":
transformer = None transformer = None
vae = None vae = None
if args.fp16 and args.bf16:
raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
if args.transformer_ckpt_path is not None: if args.transformer_ckpt_path is not None:
transformer = convert_transformer(args.transformer_ckpt_path) transformer = convert_transformer(
args.transformer_ckpt_path,
args.num_layers,
args.num_attention_heads,
args.use_rotary_positional_embeddings,
dtype,
)
if args.vae_ckpt_path is not None: if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path) vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
text_encoder_id = "google/t5-v1_1-xxl" text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
# Apparently, the conversion does not work any more without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()
scheduler = CogVideoXDDIMScheduler.from_config( scheduler = CogVideoXDDIMScheduler.from_config(
{ {
"snr_shift_scale": 3.0, "snr_shift_scale": args.snr_shift_scale,
"beta_end": 0.012, "beta_end": 0.012,
"beta_schedule": "scaled_linear", "beta_schedule": "scaled_linear",
"beta_start": 0.00085, "beta_start": 0.00085,
@@ -208,7 +249,7 @@ if __name__ == "__main__":
"prediction_type": "v_prediction", "prediction_type": "v_prediction",
"rescale_betas_zero_snr": True, "rescale_betas_zero_snr": True,
"set_alpha_to_one": True, "set_alpha_to_one": True,
"timestep_spacing": "linspace", "timestep_spacing": "trailing",
} }
) )
@@ -218,5 +259,10 @@ if __name__ == "__main__":
if args.fp16: if args.fp16:
pipe = pipe.to(dtype=torch.float16) pipe = pipe.to(dtype=torch.float16)
if args.bf16:
pipe = pipe.to(dtype=torch.bfloat16)
# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
# is either fp16/bf16 here).
pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub) pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
+1 -1
View File
@@ -254,7 +254,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup( setup(
name="diffusers", name="diffusers",
version="0.30.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) version="0.30.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.", description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(), long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
+1 -1
View File
@@ -1,4 +1,4 @@
__version__ = "0.30.0.dev0" __version__ = "0.30.1"
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
+37 -7
View File
@@ -1489,10 +1489,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
@classmethod @classmethod
@validate_hf_hub_args @validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict( def lora_state_dict(
cls, cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
return_alphas: bool = False,
**kwargs, **kwargs,
): ):
r""" r"""
@@ -1577,7 +1577,26 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
allow_pickle=allow_pickle, allow_pickle=allow_pickle,
) )
return state_dict # For state dicts like
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
keys = list(state_dict.keys())
network_alphas = {}
for k in keys:
if "alpha" in k:
alpha_value = state_dict.get(k)
if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
alpha_value, float
):
network_alphas[k] = state_dict.pop(k)
else:
raise ValueError(
f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
)
if return_alphas:
return state_dict, network_alphas
else:
return state_dict
def load_lora_weights( def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
@@ -1611,7 +1630,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded. # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
)
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format: if not is_correct_format:
@@ -1619,6 +1640,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
self.load_lora_into_transformer( self.load_lora_into_transformer(
state_dict, state_dict,
network_alphas=network_alphas,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name, adapter_name=adapter_name,
_pipeline=self, _pipeline=self,
@@ -1628,7 +1650,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
if len(text_encoder_state_dict) > 0: if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder( self.load_lora_into_text_encoder(
text_encoder_state_dict, text_encoder_state_dict,
network_alphas=None, network_alphas=network_alphas,
text_encoder=self.text_encoder, text_encoder=self.text_encoder,
prefix="text_encoder", prefix="text_encoder",
lora_scale=self.lora_scale, lora_scale=self.lora_scale,
@@ -1637,8 +1659,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
) )
@classmethod @classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
""" """
This will load the LoRA layers specified in `state_dict` into `transformer`. This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1647,6 +1668,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers. encoder lora layers.
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
transformer (`SD3Transformer2DModel`): transformer (`SD3Transformer2DModel`):
The Transformer model to load the LoRA layers into. The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*): adapter_name (`str`, *optional*):
@@ -1678,7 +1703,12 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
if "lora_B" in key: if "lora_B" in key:
rank[key] = val.shape[1] rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) if network_alphas is not None and len(network_alphas) >= 1:
prefix = cls.transformer_name
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
if "use_dora" in lora_config_kwargs: if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
raise ValueError( raise ValueError(
+2 -2
View File
@@ -23,6 +23,7 @@ from packaging import version
from ..utils import deprecate, is_transformers_available, logging from ..utils import deprecate, is_transformers_available, logging
from .single_file_utils import ( from .single_file_utils import (
SingleFileComponentError, SingleFileComponentError,
_is_legacy_scheduler_kwargs,
_is_model_weights_in_cached_folder, _is_model_weights_in_cached_folder,
_legacy_load_clip_tokenizer, _legacy_load_clip_tokenizer,
_legacy_load_safety_checker, _legacy_load_safety_checker,
@@ -42,7 +43,6 @@ logger = logging.get_logger(__name__)
# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided # Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"] SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
if is_transformers_available(): if is_transformers_available():
import transformers import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer from transformers import PreTrainedModel, PreTrainedTokenizer
@@ -135,7 +135,7 @@ def load_single_file_sub_model(
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
) )
elif is_diffusers_scheduler and is_legacy_loading: elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
loaded_sub_model = _legacy_load_scheduler( loaded_sub_model = _legacy_load_scheduler(
class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
) )
+34 -9
View File
@@ -79,7 +79,10 @@ CHECKPOINT_KEY_NAMES = {
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight", "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight", "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
"animatediff_rgb": "controlnet_cond_embedding.weight", "animatediff_rgb": "controlnet_cond_embedding.weight",
"flux": "double_blocks.0.img_attn.norm.key_norm.scale", "flux": [
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
],
} }
DIFFUSERS_DEFAULT_PIPELINE_PATHS = { DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -258,7 +261,7 @@ SCHEDULER_DEFAULT_CONFIG = {
"timestep_spacing": "leading", "timestep_spacing": "leading",
} }
LDM_VAE_KEY = "first_stage_model." LDM_VAE_KEYS = ["first_stage_model.", "vae."]
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215 LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
PLAYGROUND_VAE_SCALING_FACTOR = 0.5 PLAYGROUND_VAE_SCALING_FACTOR = 0.5
LDM_UNET_KEY = "model.diffusion_model." LDM_UNET_KEY = "model.diffusion_model."
@@ -267,8 +270,8 @@ LDM_CLIP_PREFIX_TO_REMOVE = [
"cond_stage_model.transformer.", "cond_stage_model.transformer.",
"conditioner.embedders.0.transformer.", "conditioner.embedders.0.transformer.",
] ]
OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024 LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]
VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"] VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
@@ -318,6 +321,10 @@ def _is_model_weights_in_cached_folder(cached_folder, name):
return weights_exist return weights_exist
def _is_legacy_scheduler_kwargs(kwargs):
return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys())
def load_single_file_checkpoint( def load_single_file_checkpoint(
pretrained_model_link_or_path, pretrained_model_link_or_path,
force_download=False, force_download=False,
@@ -516,8 +523,10 @@ def infer_diffusers_model_type(checkpoint):
else: else:
model_type = "animatediff_v3" model_type = "animatediff_v3"
elif CHECKPOINT_KEY_NAMES["flux"] in checkpoint: elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
if "guidance_in.in_layer.bias" in checkpoint: if any(
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
):
model_type = "flux-dev" model_type = "flux-dev"
else: else:
model_type = "flux-schnell" model_type = "flux-schnell"
@@ -1176,7 +1185,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
# remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
vae_state_dict = {} vae_state_dict = {}
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
vae_key = LDM_VAE_KEY if any(k.startswith(LDM_VAE_KEY) for k in keys) else "" vae_key = ""
for ldm_vae_key in LDM_VAE_KEYS:
if any(k.startswith(ldm_vae_key) for k in keys):
vae_key = ldm_vae_key
for key in keys: for key in keys:
if key.startswith(vae_key): if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
@@ -1477,14 +1490,22 @@ def _legacy_load_scheduler(
if scheduler_type is not None: if scheduler_type is not None:
deprecation_message = ( deprecation_message = (
"Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`." "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`\n\n"
"Example:\n\n"
"from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
"scheduler = DDIMScheduler()\n"
"pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
) )
deprecate("scheduler_type", "1.0.0", deprecation_message) deprecate("scheduler_type", "1.0.0", deprecation_message)
if prediction_type is not None: if prediction_type is not None:
deprecation_message = ( deprecation_message = (
"Please configure an instance of a Scheduler with the appropriate `prediction_type` " "Please configure an instance of a Scheduler with the appropriate `prediction_type` and "
"and pass the object directly to the `scheduler` argument in `from_single_file`." "pass the object directly to the `scheduler` argument in `from_single_file`.\n\n"
"Example:\n\n"
"from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
'scheduler = DDIMScheduler(prediction_type="v_prediction")\n'
"pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
) )
deprecate("prediction_type", "1.0.0", deprecation_message) deprecate("prediction_type", "1.0.0", deprecation_message)
@@ -1881,6 +1902,10 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {} converted_state_dict = {}
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401 num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401 num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
+142
View File
@@ -1868,6 +1868,148 @@ class FluxAttnProcessor2_0:
return hidden_states, encoder_hidden_states return hidden_states, encoder_hidden_states
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
class FusedCogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
class XFormersAttnAddedKVProcessor: class XFormersAttnAddedKVProcessor:
r""" r"""
Processor for implementing memory efficient attention using xFormers. Processor for implementing memory efficient attention using xFormers.
@@ -36,7 +36,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class CogVideoXSafeConv3d(nn.Conv3d): class CogVideoXSafeConv3d(nn.Conv3d):
""" r"""
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
""" """
@@ -68,12 +68,12 @@ class CogVideoXCausalConv3d(nn.Module):
r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
Args: Args:
in_channels (int): Number of channels in the input tensor. in_channels (`int`): Number of channels in the input tensor.
out_channels (int): Number of output channels. out_channels (`int`): Number of output channels produced by the convolution.
kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolutional kernel. kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
stride (int, optional): Stride of the convolution. Default is 1. stride (`int`, defaults to `1`): Stride of the convolution.
dilation (int, optional): Dilation rate of the convolution. Default is 1. dilation (`int`, defaults to `1`): Dilation rate of the convolution.
pad_mode (str, optional): Padding mode. Default is "constant". pad_mode (`str`, defaults to `"constant"`): Padding mode.
""" """
def __init__( def __init__(
@@ -118,19 +118,12 @@ class CogVideoXCausalConv3d(nn.Module):
self.conv_cache = None self.conv_cache = None
def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor: def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
dim = self.temporal_dim
kernel_size = self.time_kernel_size kernel_size = self.time_kernel_size
if kernel_size == 1: if kernel_size > 1:
return inputs cached_inputs = (
[self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
inputs = inputs.transpose(0, dim) )
inputs = torch.cat(cached_inputs + [inputs], dim=2)
if self.conv_cache is not None:
inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0)
else:
inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0)
inputs = inputs.transpose(0, dim).contiguous()
return inputs return inputs
def _clear_fake_context_parallel_cache(self): def _clear_fake_context_parallel_cache(self):
@@ -138,16 +131,17 @@ class CogVideoXCausalConv3d(nn.Module):
self.conv_cache = None self.conv_cache = None
def forward(self, inputs: torch.Tensor) -> torch.Tensor: def forward(self, inputs: torch.Tensor) -> torch.Tensor:
input_parallel = self.fake_context_parallel_forward(inputs) inputs = self.fake_context_parallel_forward(inputs)
self._clear_fake_context_parallel_cache() self._clear_fake_context_parallel_cache()
self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu() # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
# hundred megabytes and so let's not do it for now
self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0) inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
output_parallel = self.conv(input_parallel) output = self.conv(inputs)
output = output_parallel
return output return output
@@ -163,6 +157,8 @@ class CogVideoXSpatialNorm3D(nn.Module):
The number of channels for input to group normalization layer, and output of the spatial norm layer. The number of channels for input to group normalization layer, and output of the spatial norm layer.
zq_channels (`int`): zq_channels (`int`):
The number of channels for the quantized vector as described in the paper. The number of channels for the quantized vector as described in the paper.
groups (`int`):
Number of groups to separate the channels into for group normalization.
""" """
def __init__( def __init__(
@@ -197,17 +193,26 @@ class CogVideoXResnetBlock3D(nn.Module):
A 3D ResNet block used in the CogVideoX model. A 3D ResNet block used in the CogVideoX model.
Args: Args:
in_channels (int): Number of input channels. in_channels (`int`):
out_channels (Optional[int], optional): Number of input channels.
Number of output channels. If None, defaults to `in_channels`. Default is None. out_channels (`int`, *optional*):
dropout (float, optional): Dropout rate. Default is 0.0. Number of output channels. If None, defaults to `in_channels`.
temb_channels (int, optional): Number of time embedding channels. Default is 512. dropout (`float`, defaults to `0.0`):
groups (int, optional): Number of groups for group normalization. Default is 32. Dropout rate.
eps (float, optional): Epsilon value for normalization layers. Default is 1e-6. temb_channels (`int`, defaults to `512`):
non_linearity (str, optional): Activation function to use. Default is "swish". Number of time embedding channels.
conv_shortcut (bool, optional): If True, use a convolutional shortcut. Default is False. groups (`int`, defaults to `32`):
spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None. Number of groups to separate the channels into for group normalization.
pad_mode (str, optional): Padding mode. Default is "first". eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
non_linearity (`str`, defaults to `"swish"`):
Activation function to use.
conv_shortcut (bool, defaults to `False`):
Whether or not to use a convolution shortcut.
spatial_norm_dim (`int`, *optional*):
The dimension to use for spatial norm if it is to be used instead of group norm.
pad_mode (str, defaults to `"first"`):
Padding mode.
""" """
def __init__( def __init__(
@@ -309,18 +314,28 @@ class CogVideoXDownBlock3D(nn.Module):
A downsampling block used in the CogVideoX model. A downsampling block used in the CogVideoX model.
Args: Args:
in_channels (int): Number of input channels. in_channels (`int`):
out_channels (int): Number of output channels. Number of input channels.
temb_channels (int): Number of time embedding channels. out_channels (`int`, *optional*):
dropout (float, optional): Dropout rate. Default is 0.0. Number of output channels. If None, defaults to `in_channels`.
num_layers (int, optional): Number of layers in the block. Default is 1. temb_channels (`int`, defaults to `512`):
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6. Number of time embedding channels.
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish". num_layers (`int`, defaults to `1`):
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32. Number of resnet layers.
add_downsample (bool, optional): If True, add a downsampling layer at the end of the block. Default is True. dropout (`float`, defaults to `0.0`):
downsample_padding (int, optional): Padding for the downsampling layer. Default is 0. Dropout rate.
compress_time (bool, optional): If True, apply temporal compression. Default is False. resnet_eps (`float`, defaults to `1e-6`):
pad_mode (str, optional): Padding mode. Default is "first". Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
resnet_groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
add_downsample (`bool`, defaults to `True`):
Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
compress_time (`bool`, defaults to `False`):
Whether or not to downsample across temporal dimension.
pad_mode (str, defaults to `"first"`):
Padding mode.
""" """
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
@@ -405,15 +420,24 @@ class CogVideoXMidBlock3D(nn.Module):
A middle block used in the CogVideoX model. A middle block used in the CogVideoX model.
Args: Args:
in_channels (int): Number of input channels. in_channels (`int`):
temb_channels (int): Number of time embedding channels. Number of input channels.
dropout (float, optional): Dropout rate. Default is 0.0. temb_channels (`int`, defaults to `512`):
num_layers (int, optional): Number of layers in the block. Default is 1. Number of time embedding channels.
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6. dropout (`float`, defaults to `0.0`):
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish". Dropout rate.
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32. num_layers (`int`, defaults to `1`):
spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None. Number of resnet layers.
pad_mode (str, optional): Padding mode. Default is "first". resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
resnet_groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
spatial_norm_dim (`int`, *optional*):
The dimension to use for spatial norm if it is to be used instead of group norm.
pad_mode (str, defaults to `"first"`):
Padding mode.
""" """
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
@@ -480,19 +504,30 @@ class CogVideoXUpBlock3D(nn.Module):
An upsampling block used in the CogVideoX model. An upsampling block used in the CogVideoX model.
Args: Args:
in_channels (int): Number of input channels. in_channels (`int`):
out_channels (int): Number of output channels. Number of input channels.
temb_channels (int): Number of time embedding channels. out_channels (`int`, *optional*):
dropout (float, optional): Dropout rate. Default is 0.0. Number of output channels. If None, defaults to `in_channels`.
num_layers (int, optional): Number of layers in the block. Default is 1. temb_channels (`int`, defaults to `512`):
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6. Number of time embedding channels.
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish". dropout (`float`, defaults to `0.0`):
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32. Dropout rate.
spatial_norm_dim (int, optional): Dimension of the spatial normalization. Default is 16. num_layers (`int`, defaults to `1`):
add_upsample (bool, optional): If True, add an upsampling layer at the end of the block. Default is True. Number of resnet layers.
upsample_padding (int, optional): Padding for the upsampling layer. Default is 1. resnet_eps (`float`, defaults to `1e-6`):
compress_time (bool, optional): If True, apply temporal compression. Default is False. Epsilon value for normalization layers.
pad_mode (str, optional): Padding mode. Default is "first". resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
resnet_groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
spatial_norm_dim (`int`, defaults to `16`):
The dimension to use for spatial norm if it is to be used instead of group norm.
add_upsample (`bool`, defaults to `True`):
Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension.
compress_time (`bool`, defaults to `False`):
Whether or not to downsample across temporal dimension.
pad_mode (str, defaults to `"first"`):
Padding mode.
""" """
def __init__( def __init__(
@@ -587,14 +622,12 @@ class CogVideoXEncoder3D(nn.Module):
options. options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block. The number of output channels for each block.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
layers_per_block (`int`, *optional*, defaults to 2): layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block. The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32): norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization. The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels for the last block.
""" """
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
@@ -723,14 +756,12 @@ class CogVideoXDecoder3D(nn.Module):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block. The number of output channels for each block.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
layers_per_block (`int`, *optional*, defaults to 2): layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block. The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32): norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization. The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
norm_type (`str`, *optional*, defaults to `"group"`):
The normalization type to use. Can be either `"group"` or `"spatial"`.
""" """
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
@@ -871,7 +902,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
Tuple of block output channels. Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
sample_size (`int`, *optional*, defaults to `32`): Sample input size. sample_size (`int`, *optional*, defaults to `32`): Sample input size.
scaling_factor (`float`, *optional*, defaults to 0.18215): scaling_factor (`float`, *optional*, defaults to `1.15258426`):
The component-wise standard deviation of the trained latent space computed using the first batch of the The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
@@ -911,7 +942,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
norm_num_groups: int = 32, norm_num_groups: int = 32,
temporal_compression_ratio: float = 4, temporal_compression_ratio: float = 4,
sample_size: int = 256, sample_height: int = 480,
sample_width: int = 720,
scaling_factor: float = 1.15258426, scaling_factor: float = 1.15258426,
shift_factor: Optional[float] = None, shift_factor: Optional[float] = None,
latents_mean: Optional[Tuple[float]] = None, latents_mean: Optional[Tuple[float]] = None,
@@ -950,25 +982,105 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.use_slicing = False self.use_slicing = False
self.use_tiling = False self.use_tiling = False
self.tile_sample_min_size = self.config.sample_size # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
sample_size = ( # recommended because the temporal parts of the VAE, here, are tricky to understand.
self.config.sample_size[0] # If you decode X latent frames together, the number of output frames is:
if isinstance(self.config.sample_size, (list, tuple)) # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
else self.config.sample_size #
# Example with num_latent_frames_batch_size = 2:
# - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
# => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
# => 6 * 8 = 48 frames
# - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
# => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
# ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
# => 1 * 9 + 5 * 8 = 49 frames
# It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
# setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
# number of temporal frames.
self.num_latent_frames_batch_size = 2
# We make the minimum height and width of sample for tiling half that of the generally supported
self.tile_sample_min_height = sample_height // 2
self.tile_sample_min_width = sample_width // 2
self.tile_latent_min_height = int(
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
) )
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
# These are experimental overlap factors that were chosen based on experimentation and seem to work best for
# 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
# and so the tiling implementation has only been tested on those specific resolutions.
self.tile_overlap_factor_height = 1 / 6
self.tile_overlap_factor_width = 1 / 5
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)): if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
module.gradient_checkpointing = value module.gradient_checkpointing = value
def clear_fake_context_parallel_cache(self): def _clear_fake_context_parallel_cache(self):
for name, module in self.named_modules(): for name, module in self.named_modules():
if isinstance(module, CogVideoXCausalConv3d): if isinstance(module, CogVideoXCausalConv3d):
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}") logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
module._clear_fake_context_parallel_cache() module._clear_fake_context_parallel_cache()
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_overlap_factor_height: Optional[float] = None,
tile_overlap_factor_width: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_overlap_factor_height (`int`, *optional*):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed leading to slow down of the decoding process.
tile_overlap_factor_width (`int`, *optional*):
The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed leading to slow down of the decoding process.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_latent_min_height = int(
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
)
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook @apply_forward_hook
def encode( def encode(
self, x: torch.Tensor, return_dict: bool = True self, x: torch.Tensor, return_dict: bool = True
@@ -993,8 +1105,34 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return (posterior,) return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior) return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
frame_batch_size = self.num_latent_frames_batch_size
dec = []
for i in range(num_frames // frame_batch_size):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
z_intermediate = z[:, :, start_frame:end_frame]
if self.post_quant_conv is not None:
z_intermediate = self.post_quant_conv(z_intermediate)
z_intermediate = self.decoder(z_intermediate)
dec.append(z_intermediate)
self._clear_fake_context_parallel_cache()
dec = torch.cat(dec, dim=2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook @apply_forward_hook
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
""" """
Decode a batch of images. Decode a batch of images.
@@ -1007,13 +1145,111 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
[`~models.vae.DecoderOutput`] or `tuple`: [`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned. returned.
""" """
if self.post_quant_conv is not None: if self.use_slicing and z.shape[0] > 1:
z = self.post_quant_conv(z) decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
dec = self.decoder(z) decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
# Rough memory assessment:
# - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
# - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
# - Assume fp16 (2 bytes per value).
# Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
#
# Memory assessment when using tiling:
# - Assume everything as above but now HxW is 240x360 by tiling in half
# Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
batch_size, num_channels, num_frames, height, width = z.shape
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_sample_min_height - blend_extent_height
row_limit_width = self.tile_sample_min_width - blend_extent_width
frame_batch_size = self.num_latent_frames_batch_size
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
time = []
for k in range(num_frames // frame_batch_size):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = z[
:,
:,
start_frame:end_frame,
i : i + self.tile_latent_min_height,
j : j + self.tile_latent_min_width,
]
if self.post_quant_conv is not None:
tile = self.post_quant_conv(tile)
tile = self.decoder(tile)
time.append(tile)
self._clear_fake_context_parallel_cache()
row.append(torch.cat(time, dim=2))
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))
dec = torch.cat(result_rows, dim=3)
if not return_dict: if not return_dict:
return (dec,) return (dec,)
return DecoderOutput(sample=dec) return DecoderOutput(sample=dec)
def forward( def forward(
+84
View File
@@ -374,6 +374,90 @@ class CogVideoXPatchEmbed(nn.Module):
return embeds return embeds
def get_3d_rotary_pos_embed(
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
RoPE for video tokens with 3D structure.
Args:
embed_dim: (`int`):
The embedding dimension size, corresponding to hidden_size_head.
crops_coords (`Tuple[int]`):
The top-left and bottom-right coordinates of the crop.
grid_size (`Tuple[int]`):
The grid size of the spatial positional embedding (height, width).
temporal_size (`int`):
The size of the temporal dimension.
theta (`float`):
Scaling factor for frequency computation.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns:
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
"""
start, stop = crops_coords
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
# Compute dimensions for each axis
dim_t = embed_dim // 4
dim_h = embed_dim // 8 * 3
dim_w = embed_dim // 8 * 3
# Temporal frequencies
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
grid_t = torch.from_numpy(grid_t).float()
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
# Spatial frequencies for height and width
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
grid_h = torch.from_numpy(grid_h).float()
grid_w = torch.from_numpy(grid_w).float()
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
# Broadcast and concatenate tensors along specified dimension
def broadcast(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = {len(t.shape) for t in tensors}
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*(list(t.shape) for t in tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
), "invalid dimensions for broadcastable concatenation"
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
return torch.cat(tensors, dim=dim)
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
t, h, w, d = freqs.shape
freqs = freqs.view(t * h * w, d)
# Generate sine and cosine components
sin = freqs.sin()
cos = freqs.cos()
if use_real:
return cos, sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
""" """
RoPE for image tokens with 2d structure. RoPE for image tokens with 2d structure.
@@ -68,6 +68,21 @@ class AuraFlowPatchEmbed(nn.Module):
self.height, self.width = height // patch_size, width // patch_size self.height, self.width = height // patch_size, width // patch_size
self.base_size = height // patch_size self.base_size = height // patch_size
def pe_selection_index_based_on_dim(self, h, w):
# select subset of positional embedding based on H, W, where H, W is size of latent
# PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
# because original input are in flattened format, we have to flatten this 2d grid as well.
h_p, w_p = h // self.patch_size, w // self.patch_size
original_pe_indexes = torch.arange(self.pos_embed.shape[1])
h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
original_pe_indexes = original_pe_indexes.view(h_max, w_max)
starth = h_max // 2 - h_p // 2
endh = starth + h_p
startw = w_max // 2 - w_p // 2
endw = startw + w_p
original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
return original_pe_indexes.flatten()
def forward(self, latent): def forward(self, latent):
batch_size, num_channels, height, width = latent.size() batch_size, num_channels, height, width = latent.size()
latent = latent.view( latent = latent.view(
@@ -80,7 +95,8 @@ class AuraFlowPatchEmbed(nn.Module):
) )
latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
latent = self.proj(latent) latent = self.proj(latent)
return latent + self.pos_embed pe_index = self.pe_selection_index_based_on_dim(height, width)
return latent + self.pos_embed[:, pe_index]
# Taken from the original Aura flow inference code. # Taken from the original Aura flow inference code.
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
@@ -22,6 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward from ..attention import Attention, FeedForward
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from ..modeling_outputs import Transformer2DModelOutput from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
@@ -37,13 +38,20 @@ class CogVideoXBlock(nn.Module):
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
Parameters: Parameters:
dim (`int`): The number of channels in the input and output. dim (`int`):
num_attention_heads (`int`): The number of heads to use for multi-head attention. The number of channels in the input and output.
attention_head_dim (`int`): The number of channels in each head. num_attention_heads (`int`):
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. The number of heads to use for multi-head attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. attention_head_dim (`int`):
attention_bias (: The number of channels in each head.
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. time_embed_dim (`int`):
The number of channels in timestep embedding.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to be used in feed-forward.
attention_bias (`bool`, defaults to `False`):
Whether or not to use bias in attention projection layers.
qk_norm (`bool`, defaults to `True`): qk_norm (`bool`, defaults to `True`):
Whether or not to use normalization after query and key projections in Attention. Whether or not to use normalization after query and key projections in Attention.
norm_elementwise_affine (`bool`, defaults to `True`): norm_elementwise_affine (`bool`, defaults to `True`):
@@ -90,6 +98,7 @@ class CogVideoXBlock(nn.Module):
eps=1e-6, eps=1e-6,
bias=attention_bias, bias=attention_bias,
out_bias=attention_out_bias, out_bias=attention_out_bias,
processor=CogVideoXAttnProcessor2_0(),
) )
# 2. Feed Forward # 2. Feed Forward
@@ -109,24 +118,24 @@ class CogVideoXBlock(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
temb: torch.Tensor, temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb hidden_states, encoder_hidden_states, temb
) )
# attention # attention
text_length = norm_encoder_hidden_states.size(1) attn_hidden_states, attn_encoder_hidden_states = self.attn1(
# CogVideoX uses concatenated text + video embeddings with self-attention instead of using
# them in cross-attention individually
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
attn_output = self.attn1(
hidden_states=norm_hidden_states, hidden_states=norm_hidden_states,
encoder_hidden_states=None, encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
) )
hidden_states = hidden_states + gate_msa * attn_output[:, text_length:] hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length] encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
# norm & modulate # norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
@@ -137,8 +146,9 @@ class CogVideoXBlock(nn.Module):
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states) ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_ff * ff_output[:, text_length:] hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length] encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
return hidden_states, encoder_hidden_states return hidden_states, encoder_hidden_states
@@ -147,36 +157,53 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
Parameters: Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. num_attention_heads (`int`, defaults to `30`):
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. The number of heads to use for multi-head attention.
in_channels (`int`, *optional*): attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
in_channels (`int`, defaults to `16`):
The number of channels in the input. The number of channels in the input.
out_channels (`int`, *optional*): out_channels (`int`, *optional*, defaults to `16`):
The number of channels in the output. The number of channels in the output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. flip_sin_to_cos (`bool`, defaults to `True`):
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. Whether to flip the sin to cos in the time embedding.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. time_embed_dim (`int`, defaults to `512`):
attention_bias (`bool`, *optional*): Output dimension of timestep embeddings.
Configure if the `TransformerBlocks` attention should contain a bias parameter. text_embed_dim (`int`, defaults to `4096`):
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). Input dimension of text embeddings from the text encoder.
This is fixed during training since it is used to learn a number of position embeddings. num_layers (`int`, defaults to `30`):
patch_size (`int`, *optional*): The number of layers of Transformer blocks to use.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
attention_bias (`bool`, defaults to `True`):
Whether or not to use bias in the attention projection layers.
sample_width (`int`, defaults to `90`):
The width of the input latents.
sample_height (`int`, defaults to `60`):
The height of the input latents.
sample_frames (`int`, defaults to `49`):
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
patch_size (`int`, defaults to `2`):
The size of the patches to use in the patch embedding layer. The size of the patches to use in the patch embedding layer.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. temporal_compression_ratio (`int`, defaults to `4`):
num_embeds_ada_norm ( `int`, *optional*): The compression ratio across the temporal dimension. See documentation for `sample_frames`.
The number of diffusion steps used during training. Pass if at least one of the norm_layers is max_text_seq_length (`int`, defaults to `226`):
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are The maximum sequence length of the input text embeddings.
added to the hidden states. During inference, you can denoise for up to but not more steps than activation_fn (`str`, defaults to `"gelu-approximate"`):
`num_embeds_ada_norm`. Activation function to use in feed-forward.
norm_type (`str`, *optional*, defaults to `"layer_norm"`): timestep_activation_fn (`str`, defaults to `"silu"`):
The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`. Activation function to use when generating the timestep embeddings.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`): norm_elementwise_affine (`bool`, defaults to `True`):
Whether or not to use elementwise affine in normalization layers. Whether or not to use elementwise affine in normalization layers.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers. norm_eps (`float`, defaults to `1e-5`):
caption_channels (`int`, *optional*): The epsilon value to use in normalization layers.
The number of channels in the caption embeddings. spatial_interpolation_scale (`float`, defaults to `1.875`):
video_length (`int`, *optional*): Scaling factor to apply in 3D positional embeddings across spatial dimensions.
The number of frames in the video-like data. temporal_interpolation_scale (`float`, defaults to `1.0`):
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
""" """
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
@@ -186,7 +213,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
self, self,
num_attention_heads: int = 30, num_attention_heads: int = 30,
attention_head_dim: int = 64, attention_head_dim: int = 64,
in_channels: Optional[int] = 16, in_channels: int = 16,
out_channels: Optional[int] = 16, out_channels: Optional[int] = 16,
flip_sin_to_cos: bool = True, flip_sin_to_cos: bool = True,
freq_shift: int = 0, freq_shift: int = 0,
@@ -207,6 +234,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
norm_eps: float = 1e-5, norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.875, spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0, temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
): ):
super().__init__() super().__init__()
inner_dim = num_attention_heads * attention_head_dim inner_dim = num_attention_heads * attention_head_dim
@@ -271,12 +299,113 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value self.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor], timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
return_dict: bool = True, return_dict: bool = True,
): ):
batch_size, num_frames, channels, height, width = hidden_states.shape batch_size, num_frames, channels, height, width = hidden_states.shape
@@ -295,16 +424,18 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
# 3. Position embedding # 3. Position embedding
seq_length = height * width * num_frames // (self.config.patch_size**2) text_seq_length = encoder_hidden_states.shape[1]
if not self.config.use_rotary_positional_embeddings:
seq_length = height * width * num_frames // (self.config.patch_size**2)
pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length] pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
hidden_states = hidden_states + pos_embeds hidden_states = hidden_states + pos_embeds
hidden_states = self.embedding_dropout(hidden_states) hidden_states = self.embedding_dropout(hidden_states)
encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length] encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, self.config.max_text_seq_length :] hidden_states = hidden_states[:, text_seq_length:]
# 5. Transformer blocks # 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks): for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing: if self.training and self.gradient_checkpointing:
@@ -320,6 +451,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states, hidden_states,
encoder_hidden_states, encoder_hidden_states,
emb, emb,
image_rotary_emb,
**ckpt_kwargs, **ckpt_kwargs,
) )
else: else:
@@ -327,15 +459,23 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states=hidden_states, hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
temb=emb, temb=emb,
image_rotary_emb=image_rotary_emb,
) )
hidden_states = self.norm_final(hidden_states) if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
hidden_states = self.norm_final(hidden_states)
else:
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
# 6. Final block # 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb) hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states) hidden_states = self.proj_out(hidden_states)
# 7. Unpatchify # 6. Unpatchify
p = self.config.patch_size p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p) output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
@@ -23,6 +23,7 @@ from transformers import T5EncoderModel, T5Tokenizer
from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import BaseOutput, logging, replace_example_docstring from ...utils import BaseOutput, logging, replace_example_docstring
@@ -40,6 +41,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import CogVideoXPipeline >>> from diffusers import CogVideoXPipeline
>>> from diffusers.utils import export_to_video >>> from diffusers.utils import export_to_video
>>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
>>> prompt = ( >>> prompt = (
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
@@ -55,6 +57,25 @@ EXAMPLE_DOC_STRING = """
""" """
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps( def retrieve_timesteps(
scheduler, scheduler,
@@ -332,20 +353,11 @@ class CogVideoXPipeline(DiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma latents = latents * self.scheduler.init_noise_sigma
return latents return latents
def decode_latents(self, latents: torch.Tensor, num_seconds: int): def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae.config.scaling_factor * latents latents = 1 / self.vae.config.scaling_factor * latents
frames = [] frames = self.vae.decode(latents).sample
for i in range(num_seconds):
start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
current_frames = self.vae.decode(latents[:, :, start_frame:end_frame]).sample
frames.append(current_frames)
self.vae.clear_fake_context_parallel_cache()
frames = torch.cat(frames, dim=2)
return frames return frames
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -418,6 +430,46 @@ class CogVideoXPipeline(DiffusionPipeline):
f" {negative_prompt_embeds.shape}." f" {negative_prompt_embeds.shape}."
) )
def fuse_qkv_projections(self) -> None:
r"""Enables fused QKV projections."""
self.fusing_transformer = True
self.transformer.fuse_qkv_projections()
def unfuse_qkv_projections(self) -> None:
r"""Disable QKV projection fusion if enabled."""
if not self.fusing_transformer:
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
else:
self.transformer.unfuse_qkv_projections()
self.fusing_transformer = False
def _prepare_rotary_positional_embeddings(
self,
height: int,
width: int,
num_frames: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
use_real=True,
)
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
@property @property
def guidance_scale(self): def guidance_scale(self):
return self._guidance_scale return self._guidance_scale
@@ -438,8 +490,7 @@ class CogVideoXPipeline(DiffusionPipeline):
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 480, height: int = 480,
width: int = 720, width: int = 720,
num_frames: int = 48, num_frames: int = 49,
fps: int = 8,
num_inference_steps: int = 50, num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None, timesteps: Optional[List[int]] = None,
guidance_scale: float = 6, guidance_scale: float = 6,
@@ -534,9 +585,10 @@ class CogVideoXPipeline(DiffusionPipeline):
`tuple`. When returning a tuple, the first element is a list with the generated images. `tuple`. When returning a tuple, the first element is a list with the generated images.
""" """
assert ( if num_frames > 49:
num_frames <= 48 and num_frames % fps == 0 and fps == 8 raise ValueError(
), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX." "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -593,7 +645,6 @@ class CogVideoXPipeline(DiffusionPipeline):
# 5. Prepare latents. # 5. Prepare latents.
latent_channels = self.transformer.config.in_channels latent_channels = self.transformer.config.in_channels
num_frames += 1
latents = self.prepare_latents( latents = self.prepare_latents(
batch_size * num_videos_per_prompt, batch_size * num_videos_per_prompt,
latent_channels, latent_channels,
@@ -609,7 +660,14 @@ class CogVideoXPipeline(DiffusionPipeline):
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop # 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -630,6 +688,7 @@ class CogVideoXPipeline(DiffusionPipeline):
hidden_states=latent_model_input, hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
timestep=timestep, timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False, return_dict=False,
)[0] )[0]
noise_pred = noise_pred.float() noise_pred = noise_pred.float()
@@ -673,7 +732,7 @@ class CogVideoXPipeline(DiffusionPipeline):
progress_bar.update() progress_bar.update()
if not output_type == "latent": if not output_type == "latent":
video = self.decode_latents(latents, num_frames // fps) video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type) video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else: else:
video = latents video = latents
+50 -3
View File
@@ -9,7 +9,7 @@ import numpy as np
import PIL.Image import PIL.Image
import PIL.ImageOps import PIL.ImageOps
from .import_utils import BACKENDS_MAPPING, is_opencv_available from .import_utils import BACKENDS_MAPPING, is_imageio_available, is_opencv_available
from .logging import get_logger from .logging import get_logger
@@ -112,9 +112,9 @@ def export_to_obj(mesh, output_obj_path: str = None):
f.writelines("\n".join(combined_data)) f.writelines("\n".join(combined_data))
def export_to_video( def _legacy_export_to_video(
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10 video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10
) -> str: ):
if is_opencv_available(): if is_opencv_available():
import cv2 import cv2
else: else:
@@ -134,4 +134,51 @@ def export_to_video(
for i in range(len(video_frames)): for i in range(len(video_frames)):
img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
video_writer.write(img) video_writer.write(img)
return output_video_path
def export_to_video(
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10
) -> str:
# TODO: Dhruv. Remove by Diffusers release 0.33.0
# Added to prevent breaking existing code
if not is_imageio_available():
logger.warning(
(
"It is recommended to use `export_to_video` with `imageio` and `imageio-ffmpeg` as a backend. \n"
"These libraries are not present in your environment. Attempting to use legacy OpenCV backend to export video. \n"
"Support for the OpenCV backend will be deprecated in a future Diffusers version"
)
)
return _legacy_export_to_video(video_frames, output_video_path, fps)
if is_imageio_available():
import imageio
else:
raise ImportError(BACKENDS_MAPPING["imageio"][1].format("export_to_video"))
try:
imageio.plugins.ffmpeg.get_exe()
except AttributeError:
raise AttributeError(
(
"Found an existing imageio backend in your environment. Attempting to export video with imageio. \n"
"Unable to find a compatible ffmpeg installation in your environment to use with imageio. Please install via `pip install imageio-ffmpeg"
)
)
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
if isinstance(video_frames[0], np.ndarray):
video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]
elif isinstance(video_frames[0], PIL.Image.Image):
video_frames = [np.array(frame) for frame in video_frames]
with imageio.get_writer(output_video_path, fps=fps) as writer:
for frame in video_frames:
writer.append_data(frame)
return output_video_path return output_video_path
+19
View File
@@ -330,6 +330,15 @@ except importlib_metadata.PackageNotFoundError:
_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ) _is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
_imageio_available = importlib.util.find_spec("imageio") is not None
if _imageio_available:
try:
_imageio_version = importlib_metadata.version("imageio")
logger.debug(f"Successfully imported imageio version {_imageio_version}")
except importlib_metadata.PackageNotFoundError:
_imageio_available = False
def is_torch_available(): def is_torch_available():
return _torch_available return _torch_available
@@ -447,6 +456,10 @@ def is_sentencepiece_available():
return _sentencepiece_available return _sentencepiece_available
def is_imageio_available():
return _imageio_available
# docstyle-ignore # docstyle-ignore
FLAX_IMPORT_ERROR = """ FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -575,6 +588,11 @@ BITSANDBYTES_IMPORT_ERROR = """
{0} requires the bitsandbytes library but it was not found in your environment. You can install it with pip: `pip install bitsandbytes` {0} requires the bitsandbytes library but it was not found in your environment. You can install it with pip: `pip install bitsandbytes`
""" """
# docstyle-ignore
IMAGEIO_IMPORT_ERROR = """
{0} requires the imageio library and ffmpeg but it was not found in your environment. You can install it with pip: `pip install imageio imageio-ffmpeg`
"""
BACKENDS_MAPPING = OrderedDict( BACKENDS_MAPPING = OrderedDict(
[ [
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
@@ -599,6 +617,7 @@ BACKENDS_MAPPING = OrderedDict(
("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)), ("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)),
("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
] ]
) )
+16 -12
View File
@@ -6,7 +6,7 @@ import PIL.Image
import PIL.ImageOps import PIL.ImageOps
import requests import requests
from .import_utils import BACKENDS_MAPPING, is_opencv_available from .import_utils import BACKENDS_MAPPING, is_imageio_available
def load_image( def load_image(
@@ -81,7 +81,8 @@ def load_video(
if is_url: if is_url:
video_data = requests.get(video, stream=True).raw video_data = requests.get(video, stream=True).raw
video_path = tempfile.NamedTemporaryFile(suffix=os.path.splitext(video)[1], delete=False).name suffix = os.path.splitext(video)[1] or ".mp4"
video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name
was_tempfile_created = True was_tempfile_created = True
with open(video_path, "wb") as f: with open(video_path, "wb") as f:
f.write(video_data.read()) f.write(video_data.read())
@@ -99,19 +100,22 @@ def load_video(
pass pass
else: else:
if is_opencv_available(): if is_imageio_available():
import cv2 import imageio
else: else:
raise ImportError(BACKENDS_MAPPING["opencv"][1].format("load_video")) raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video"))
video_capture = cv2.VideoCapture(video) try:
success, frame = video_capture.read() imageio.plugins.ffmpeg.get_exe()
while success: except AttributeError:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) raise AttributeError(
pil_images.append(PIL.Image.fromarray(frame)) "`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg"
success, frame = video_capture.read() )
video_capture.release() with imageio.get_reader(video) as reader:
# Read all frames
for frame in reader:
pil_images.append(PIL.Image.fromarray(frame))
if was_tempfile_created: if was_tempfile_created:
os.remove(video_path) os.remove(video_path)
+57 -2
View File
@@ -12,19 +12,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import sys import sys
import tempfile
import unittest import unittest
import numpy as np
import safetensors.torch
import torch import torch
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend from diffusers.utils.testing_utils import floats_tensor, is_peft_available, require_peft_backend, torch_device
if is_peft_available():
from peft.utils import get_peft_model_state_dict
sys.path.append(".") sys.path.append(".")
from utils import PeftLoraLoaderMixinTests # noqa: E402 from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
@require_peft_backend @require_peft_backend
@@ -90,3 +97,51 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_inputs.update({"generator": generator}) pipeline_inputs.update({"generator": generator})
return noise, input_ids, pipeline_inputs return noise, input_ids, pipeline_inputs
def test_with_alpha_in_state_dict(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.transformer.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
with tempfile.TemporaryDirectory() as tmpdirname:
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
# modify the state dict to have alpha values following
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
state_dict_with_alpha = safetensors.torch.load_file(
os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
)
alpha_dict = {}
for k, v in state_dict_with_alpha.items():
# only do for `transformer` and for the k projections -- should be enough to test.
if "transformer" in k and "to_k" in k and "lora_A" in k:
alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
state_dict_with_alpha.update(alpha_dict)
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
pipe.unload_lora_weights()
pipe.load_lora_weights(state_dict_with_alpha)
images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
@@ -0,0 +1,82 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogVideoXTransformer3DModel
main_input_name = "hidden_states"
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
num_frames = 1
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (1, 4, 8, 8)
@property
def output_shape(self):
return (1, 4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
"num_attention_heads": 2,
"attention_head_dim": 8,
"in_channels": 4,
"out_channels": 4,
"time_embed_dim": 2,
"text_embed_dim": 8,
"num_layers": 1,
"sample_width": 8,
"sample_height": 8,
"sample_frames": 8,
"patch_size": 2,
"temporal_compression_ratio": 4,
"max_text_seq_length": 8,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@@ -20,6 +20,7 @@ from diffusers import (
) )
from diffusers.models.attention import FreeNoiseTransformerBlock from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import logging from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import torch_device from diffusers.utils.testing_utils import torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -329,6 +330,13 @@ class AnimateDiffControlNetPipelineFastTests(
inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device) inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device)
pipe(**inputs) pipe(**inputs)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
super()._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
def test_free_init(self): def test_free_init(self):
components = self.get_dummy_components() components = self.get_dummy_components()
pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components) pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components)
@@ -19,6 +19,7 @@ from diffusers import (
UNetMotionModel, UNetMotionModel,
) )
from diffusers.utils import logging from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import torch_device from diffusers.utils.testing_utils import torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -393,6 +394,13 @@ class AnimateDiffSparseControlNetPipelineFastTests(
inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device) inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device)
pipe(**inputs) pipe(**inputs)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
super()._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
def test_free_init(self): def test_free_init(self):
components = self.get_dummy_components() components = self.get_dummy_components()
pipe: AnimateDiffSparseControlNetPipeline = self.pipeline_class(**components) pipe: AnimateDiffSparseControlNetPipeline = self.pipeline_class(**components)
+80 -8
View File
@@ -30,7 +30,12 @@ from diffusers.utils.testing_utils import (
) )
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np from ..test_pipelines_common import (
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
to_np,
)
enable_full_determinism() enable_full_determinism()
@@ -125,11 +130,6 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# Cannot reduce because convolution kernel becomes bigger than sample # Cannot reduce because convolution kernel becomes bigger than sample
"height": 16, "height": 16,
"width": 16, "width": 16,
# TODO(aryan): improve this
# Cannot make this lower due to assert condition in pipeline at the moment.
# The reason why 8 can't be used here is due to how context-parallel cache works where the first
# second of video is decoded from latent frames (0, 3) instead of [(0, 2), (2, 3)]. If 8 is used,
# the number of output frames that you get are 5.
"num_frames": 8, "num_frames": 8,
"max_sequence_length": 16, "max_sequence_length": 16,
"output_type": "pt", "output_type": "pt",
@@ -148,8 +148,8 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
video = pipe(**inputs).frames video = pipe(**inputs).frames
generated_video = video[0] generated_video = video[0]
self.assertEqual(generated_video.shape, (9, 3, 16, 16)) self.assertEqual(generated_video.shape, (8, 3, 16, 16))
expected_video = torch.randn(9, 3, 16, 16) expected_video = torch.randn(8, 3, 16, 16)
max_diff = np.abs(generated_video - expected_video).max() max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10) self.assertLessEqual(max_diff, 1e10)
@@ -250,6 +250,78 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"Attention slicing should not affect the inference results", "Attention slicing should not affect the inference results",
) )
def test_vae_tiling(self, expected_diff_max: float = 0.2):
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_overlap_factor_height=1 / 12,
tile_overlap_factor_width=1 / 12,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
@unittest.skip("xformers attention processor does not exist for CogVideoX")
def test_xformers_attention_forwardGenerator_pass(self):
pass
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
frames = pipe(**inputs).frames # [B, F, C, H, W]
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(
pipe.transformer
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
frames = pipe(**inputs).frames
image_slice_fused = frames[0, -2:, -1, -3:, -3:]
pipe.transformer.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
@slow @slow
@require_torch_gpu @require_torch_gpu
+8
View File
@@ -28,6 +28,7 @@ from diffusers import (
LattePipeline, LattePipeline,
LatteTransformer3DModel, LatteTransformer3DModel,
) )
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import (
enable_full_determinism, enable_full_determinism,
numpy_cosine_similarity_distance, numpy_cosine_similarity_distance,
@@ -256,6 +257,13 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, 1.0) self.assertLess(max_diff, 1.0)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
super()._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
@slow @slow
@require_torch_gpu @require_torch_gpu