Compare commits

..

1 Commits

Author SHA1 Message Date
Arthur Zucker 6da800931e test forcing tokenizers 2024-08-08 17:55:45 +02:00
116 changed files with 468 additions and 8365 deletions
-1
View File
@@ -202,7 +202,6 @@ Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz9
- https://github.com/microsoft/TaskMatrix
- https://github.com/invoke-ai/InvokeAI
- https://github.com/InstantID/InstantID
- https://github.com/apple/ml-stable-diffusion
- https://github.com/Sanster/lama-cleaner
- https://github.com/IDEA-Research/Grounded-Segment-Anything
+23 -20
View File
@@ -15,7 +15,9 @@
# CogVideoX
[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://arxiv.org/abs/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang.
<!-- TODO: update paper with ArXiv link when ready. -->
[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) from Tsinghua University & ZhipuAI.
The abstract from the paper is:
@@ -41,42 +43,43 @@ from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b").to("cuda")
prompt = (
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
"atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```
Then change the memory layout of the pipelines `transformer` component to `torch.channels_last`:
Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`:
```python
pipe.transformer.to(memory_format=torch.channels_last)
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
```
Finally, compile the components and run inference:
```python
pipe.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline.transformer = torch.compile(pipeline.transformer)
pipeline.vae.decode = torch.compile(pipeline.vae.decode)
# CogVideoX works well with long and well-described prompts
# CogVideoX works very well with long and well-described prompts
prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
```
The [benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are:
The [benchmark](TODO: link) results on an 80GB A100 machine are:
```
Without torch.compile(): Average inference time: 96.89 seconds.
With torch.compile(): Average inference time: 76.27 seconds.
Without torch.compile(): Average inference time: TODO seconds.
With torch.compile(): Average inference time: TODO seconds.
```
### Memory optimization
CogVideoX requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
- `pipe.enable_model_cpu_offload()`:
- Without enabling cpu offloading, memory usage is `33 GB`
- With enabling cpu offloading, memory usage is `19 GB`
- `pipe.vae.enable_tiling()`:
- With enabling cpu offloading and tiling, memory usage is `11 GB`
- `pipe.vae.enable_slicing()`
## CogVideoXPipeline
[[autodoc]] CogVideoXPipeline
+2 -16
View File
@@ -1,4 +1,4 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
<!--Copyright 2023 The HuggingFace Team and The InstantX Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
@@ -22,16 +22,7 @@ The abstract from the paper is:
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
This controlnet code is mainly implemented by [The InstantX Team](https://huggingface.co/InstantX). The inpainting-related code was developed by [The Alimama Creative Team](https://huggingface.co/alimama-creative). You can find pre-trained checkpoints for SD3-ControlNet in the table below:
| ControlNet type | Developer | Link |
| -------- | ---------- | ---- |
| Canny | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Canny) |
| Pose | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Pose) |
| Tile | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Tile) |
| Inpainting | [The AlimamaCreative Team](https://huggingface.co/alimama-creative) | [link](https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting) |
This code is implemented by [The InstantX Team](https://huggingface.co/InstantX). You can find pre-trained checkpoints for SD3-ControlNet on [The InstantX Team](https://huggingface.co/InstantX) Hub profile.
<Tip>
@@ -44,10 +35,5 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
- all
- __call__
## StableDiffusion3ControlNetInpaintingPipeline
[[autodoc]] pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet_inpainting.StableDiffusion3ControlNetInpaintingPipeline
- all
- __call__
## StableDiffusion3PipelineOutput
[[autodoc]] pipelines.stable_diffusion_3.pipeline_output.StableDiffusion3PipelineOutput
+2 -2
View File
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/kolors_header_collage.png)
Kolors is a large-scale text-to-image generation model based on latent diffusion, developed by [the Kuaishou Kolors team](https://github.com/Kwai-Kolors/Kolors). Trained on billions of text-image pairs, Kolors exhibits significant advantages over both open-source and closed-source models in visual quality, complex semantic accuracy, and text rendering for both Chinese and English characters. Furthermore, Kolors supports both Chinese and English inputs, demonstrating strong performance in understanding and generating Chinese-specific content. For more details, please refer to this [technical report](https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf).
Kolors is a large-scale text-to-image generation model based on latent diffusion, developed by [the Kuaishou Kolors team](kwai-kolors@kuaishou.com). Trained on billions of text-image pairs, Kolors exhibits significant advantages over both open-source and closed-source models in visual quality, complex semantic accuracy, and text rendering for both Chinese and English characters. Furthermore, Kolors supports both Chinese and English inputs, demonstrating strong performance in understanding and generating Chinese-specific content. For more details, please refer to this [technical report](https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf).
The abstract from the technical report is:
@@ -74,7 +74,7 @@ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
pipe = KolorsPipeline.from_pretrained(
"Kwai-Kolors/Kolors-diffusers", image_encoder=image_encoder, torch_dtype=torch.float16, variant="fp16"
)
).to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
pipe.load_ip_adapter(
+2 -2
View File
@@ -20,7 +20,7 @@ The abstract from the paper is:
*Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.*
PAG can be used by specifying the `pag_applied_layers` as a parameter when instantiating a PAG pipeline. It can be a single string or a list of strings. Each string can be a unique layer identifier or a regular expression to identify one or more layers.
PAG can be used by specifying the `pag_applied_layers` as a parameter when instantiating a PAG pipeline. It can be a single string or a list of strings. Each string can be a unique layer identifier or a regular expression to identify one or more layers.
- Full identifier as a normal string: `down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor`
- Full identifier as a RegEx: `down_blocks.2.(attentions|motion_modules).0.transformer_blocks.0.attn1.processor`
@@ -46,7 +46,7 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
## KolorsPAGPipeline
[[autodoc]] KolorsPAGPipeline
- all
- __call__
- __call__
## StableDiffusionPAGPipeline
[[autodoc]] StableDiffusionPAGPipeline
@@ -48,7 +48,7 @@ accelerate launch run_distributed.py --num_processes=2
<Tip>
Refer to this minimal example [script](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) for running inference across multiple GPUs. To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
</Tip>
@@ -108,4 +108,4 @@ torchrun run_distributed.py --nproc_per_node=2
```
> [!TIP]
> You can use `device_map` within a [`DiffusionPipeline`] to distribute its model-level components on multiple devices. Refer to the [Device placement](../tutorials/inference_with_big_models#device-placement) guide to learn more.
> You can use `device_map` within a [`DiffusionPipeline`] to distribute its model-level components on multiple devices. Refer to the [Device placement](../tutorials/inference_with_big_models#device-placement) guide to learn more.
+15 -15
View File
@@ -14,9 +14,9 @@ specific language governing permissions and limitations under the License.
It can be fun and creative to use multiple [LoRAs]((https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora)) together to generate something entirely new and unique. This works by merging multiple LoRA weights together to produce images that are a blend of different styles. Diffusers provides a few methods to merge LoRAs depending on *how* you want to merge their weights, which can affect image quality.
This guide will show you how to merge LoRAs using the [`~loaders.PeftAdapterMixin.set_adapters`] and [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) methods. To improve inference speed and reduce memory-usage of merged LoRAs, you'll also see how to use the [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method to fuse the LoRA weights with the original weights of the underlying model.
This guide will show you how to merge LoRAs using the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] and [`~peft.LoraModel.add_weighted_adapter`] methods. To improve inference speed and reduce memory-usage of merged LoRAs, you'll also see how to use the [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method to fuse the LoRA weights with the original weights of the underlying model.
For this guide, load a Stable Diffusion XL (SDXL) checkpoint and the [KappaNeuro/studio-ghibli-style](https://huggingface.co/KappaNeuro/studio-ghibli-style) and [Norod78/sdxl-chalkboarddrawing-lora](https://huggingface.co/Norod78/sdxl-chalkboarddrawing-lora) LoRAs with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method. You'll need to assign each LoRA an `adapter_name` to combine them later.
For this guide, load a Stable Diffusion XL (SDXL) checkpoint and the [KappaNeuro/studio-ghibli-style]() and [Norod78/sdxl-chalkboarddrawing-lora]() LoRAs with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method. You'll need to assign each LoRA an `adapter_name` to combine them later.
```py
from diffusers import DiffusionPipeline
@@ -29,7 +29,7 @@ pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_
## set_adapters
The [`~loaders.PeftAdapterMixin.set_adapters`] method merges LoRA adapters by concatenating their weighted matrices. Use the adapter name to specify which LoRAs to merge, and the `adapter_weights` parameter to control the scaling for each LoRA. For example, if `adapter_weights=[0.5, 0.5]`, then the merged LoRA output is an average of both LoRAs. Try adjusting the adapter weights to see how it affects the generated image!
The [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] method merges LoRA adapters by concatenating their weighted matrices. Use the adapter name to specify which LoRAs to merge, and the `adapter_weights` parameter to control the scaling for each LoRA. For example, if `adapter_weights=[0.5, 0.5]`, then the merged LoRA output is an average of both LoRAs. Try adjusting the adapter weights to see how it affects the generated image!
```py
pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
@@ -47,19 +47,19 @@ image
## add_weighted_adapter
> [!WARNING]
> This is an experimental method that adds PEFTs [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method to Diffusers to enable more efficient merging methods. Check out this [issue](https://github.com/huggingface/diffusers/issues/6892) if you're interested in learning more about the motivation and design behind this integration.
> This is an experimental method that adds PEFTs [`~peft.LoraModel.add_weighted_adapter`] method to Diffusers to enable more efficient merging methods. Check out this [issue](https://github.com/huggingface/diffusers/issues/6892) if you're interested in learning more about the motivation and design behind this integration.
The [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method provides access to more efficient merging method such as [TIES and DARE](https://huggingface.co/docs/peft/developer_guides/model_merging). To use these merging methods, make sure you have the latest stable version of Diffusers and PEFT installed.
The [`~peft.LoraModel.add_weighted_adapter`] method provides access to more efficient merging method such as [TIES and DARE](https://huggingface.co/docs/peft/developer_guides/model_merging). To use these merging methods, make sure you have the latest stable version of Diffusers and PEFT installed.
```bash
pip install -U diffusers peft
```
There are three steps to merge LoRAs with the [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method:
There are three steps to merge LoRAs with the [`~peft.LoraModel.add_weighted_adapter`] method:
1. Create a [PeftModel](https://huggingface.co/docs/peft/package_reference/peft_model#peft.PeftModel) from the underlying model and LoRA checkpoint.
1. Create a [`~peft.PeftModel`] from the underlying model and LoRA checkpoint.
2. Load a base UNet model and the LoRA adapters.
3. Merge the adapters using the [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method and the merging method of your choice.
3. Merge the adapters using the [`~peft.LoraModel.add_weighted_adapter`] method and the merging method of your choice.
Let's dive deeper into what these steps entail.
@@ -92,7 +92,7 @@ pipeline = DiffusionPipeline.from_pretrained(
pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
```
Now you'll create a [PeftModel](https://huggingface.co/docs/peft/package_reference/peft_model#peft.PeftModel) from the loaded LoRA checkpoint by combining the SDXL UNet and the LoRA UNet from the pipeline.
Now you'll create a [`~peft.PeftModel`] from the loaded LoRA checkpoint by combining the SDXL UNet and the LoRA UNet from the pipeline.
```python
from peft import get_peft_model, LoraConfig
@@ -112,7 +112,7 @@ ikea_peft_model.load_state_dict(original_state_dict, strict=True)
> [!TIP]
> You can optionally push the ikea_peft_model to the Hub by calling `ikea_peft_model.push_to_hub("ikea_peft_model", token=TOKEN)`.
Repeat this process to create a [PeftModel](https://huggingface.co/docs/peft/package_reference/peft_model#peft.PeftModel) from the [lordjia/by-feng-zikai](https://huggingface.co/lordjia/by-feng-zikai) LoRA.
Repeat this process to create a [`~peft.PeftModel`] from the [lordjia/by-feng-zikai](https://huggingface.co/lordjia/by-feng-zikai) LoRA.
```python
pipeline.delete_adapters("ikea")
@@ -148,7 +148,7 @@ model = PeftModel.from_pretrained(base_unet, "stevhliu/ikea_peft_model", use_saf
model.load_adapter("stevhliu/feng_peft_model", use_safetensors=True, subfolder="feng", adapter_name="feng")
```
3. Merge the adapters using the [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method and the merging method of your choice (learn more about other merging methods in this [blog post](https://huggingface.co/blog/peft_merging)). For this example, let's use the `"dare_linear"` method to merge the LoRAs.
3. Merge the adapters using the [`~peft.LoraModel.add_weighted_adapter`] method and the merging method of your choice (learn more about other merging methods in this [blog post](https://huggingface.co/blog/peft_merging)). For this example, let's use the `"dare_linear"` method to merge the LoRAs.
> [!WARNING]
> Keep in mind the LoRAs need to have the same rank to be merged!
@@ -182,9 +182,9 @@ image
## fuse_lora
Both the [`~loaders.PeftAdapterMixin.set_adapters`] and [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) methods require loading the base model and the LoRA adapters separately which incurs some overhead. The [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method allows you to fuse the LoRA weights directly with the original weights of the underlying model. This way, you're only loading the model once which can increase inference and lower memory-usage.
Both the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] and [`~peft.LoraModel.add_weighted_adapter`] methods require loading the base model and the LoRA adapters separately which incurs some overhead. The [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method allows you to fuse the LoRA weights directly with the original weights of the underlying model. This way, you're only loading the model once which can increase inference and lower memory-usage.
You can use PEFT to easily fuse/unfuse multiple adapters directly into the model weights (both UNet and text encoder) using the [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method, which can lead to a speed-up in inference and lower VRAM usage.
You can use PEFT to easily fuse/unfuse multiple adapters directly into the model weights (both UNet and text encoder) using the [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method, which can lead to a speed-up in inference and lower VRAM usage.
For example, if you have a base model and adapters loaded and set as active with the following adapter weights:
@@ -199,7 +199,7 @@ pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_
pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
```
Fuse these LoRAs into the UNet with the [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method. The `lora_scale` parameter controls how much to scale the output by with the LoRA weights. It is important to make the `lora_scale` adjustments in the [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method because it wont work if you try to pass `scale` to the `cross_attention_kwargs` in the pipeline.
Fuse these LoRAs into the UNet with the [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method. The `lora_scale` parameter controls how much to scale the output by with the LoRA weights. It is important to make the `lora_scale` adjustments in the [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method because it wont work if you try to pass `scale` to the `cross_attention_kwargs` in the pipeline.
```py
pipeline.fuse_lora(adapter_names=["ikea", "feng"], lora_scale=1.0)
@@ -226,7 +226,7 @@ image = pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai"
image
```
You can call [`~~loaders.lora_base.LoraBaseMixin.unfuse_lora`] to restore the original model's weights (for example, if you want to use a different `lora_scale` value). However, this only works if you've only fused one LoRA adapter to the original model. If you've fused multiple LoRAs, you'll need to reload the model.
You can call [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] to restore the original model's weights (for example, if you want to use a different `lora_scale` value). However, this only works if you've only fused one LoRA adapter to the original model. If you've fused multiple LoRAs, you'll need to reload the model.
```py
pipeline.unfuse_lora()
@@ -71,7 +71,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -43,7 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
class MarigoldDepthOutput(BaseOutput):
@@ -73,7 +73,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -66,7 +66,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = logging.getLogger(__name__)
+1 -1
View File
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
-195
View File
@@ -1,195 +0,0 @@
# DreamBooth training example for FLUX.1 [dev]
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
The `train_dreambooth_flux.py` script shows how to implement the training procedure and adapt it for [FLUX.1 [dev]](https://blackforestlabs.ai/announcing-black-forest-labs/). We also provide a LoRA implementation in the `train_dreambooth_lora_flux.py` script.
> [!NOTE]
> **Memory consumption**
>
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
> a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training.
> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)
> [!NOTE]
> **Gated model**
>
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows youve accepted the gate. Use the command below to log in:
```bash
huggingface-cli login
```
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/dreambooth` folder and run
```bash
pip install -r requirements_flux.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Dog toy example
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
Now, we can launch training using:
```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux"
accelerate launch train_dreambooth_flux.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
> [!NOTE]
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
> [!TIP]
> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
## LoRA + DreamBooth
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
To perform DreamBooth with LoRA, run:
```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux-lora"
accelerate launch train_dreambooth_lora_flux.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-5 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
### Text Encoder Training
Alongside the transformer, fine-tuning of the CLIP text encoder is also supported.
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
> [!NOTE]
> FLUX.1 has 2 text encoders (CLIP L/14 and T5-v1.1-XXL).
By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is performed.
> At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled.
To perform DreamBooth LoRA with text-encoder training, run:
```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export OUTPUT_DIR="trained-flux-dev-dreambooth-lora"
accelerate launch train_dreambooth_lora_flux.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--train_text_encoder\
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-5 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--seed="0" \
--push_to_hub
```
## Other notes
Thanks to `bghira` for their help with reviewing & insight sharing ♥️
@@ -1,8 +0,0 @@
accelerate>=0.31.0
torchvision
transformers>=4.41.2
ftfy
tensorboard
Jinja2
peft>=0.11.1
sentencepiece
-203
View File
@@ -1,203 +0,0 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
import sys
import tempfile
from diffusers import DiffusionPipeline, FluxTransformer2DModel
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothFlux(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
instance_prompt = "photo"
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
script_path = "examples/dreambooth/train_dreambooth_flux.py"
def test_dreambooth(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
def test_dreambooth_checkpointing(self):
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 4, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 4
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--seed=0
""".split()
run_command(self._launch_args + initial_run_args)
# check can run the original fully trained output pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir)
pipe(self.instance_prompt, num_inference_steps=1)
# check checkpoint directories exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
# check can run an intermediate checkpoint
transformer = FluxTransformer2DModel.from_pretrained(tmpdir, subfolder="checkpoint-2/transformer")
pipe = DiffusionPipeline.from_pretrained(self.pretrained_model_name_or_path, transformer=transformer)
pipe(self.instance_prompt, num_inference_steps=1)
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
# Run training script for 7 total steps resuming from checkpoint 4
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 6
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--seed=0
""".split()
run_command(self._launch_args + resume_run_args)
# check can run new fully trained pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir)
pipe(self.instance_prompt, num_inference_steps=1)
# check old checkpoints do not exist
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
# check new checkpoints exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
def test_dreambooth_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4"},
)
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
@@ -1,165 +0,0 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
import safetensors
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
instance_prompt = "photo"
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
script_path = "examples/dreambooth/train_dreambooth_lora_flux.py"
def test_dreambooth_lora_flux(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_text_encoder_flux(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--train_text_encoder
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
starts_with_expected_prefix = all(
(key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
)
self.assertTrue(starts_with_expected_prefix)
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+1 -1
View File
@@ -63,7 +63,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -70,7 +70,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
File diff suppressed because it is too large Load Diff
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -1454,7 +1454,7 @@ def main(args):
)
# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
if not args.train_text_encoder and train_dataset.custom_instance_prompts:
del tokenizers, text_encoders
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
del text_encoder_one, text_encoder_two, text_encoder_three
@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -64,7 +64,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -57,7 +57,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -2,8 +2,8 @@ diffusers==0.20.1
accelerate==0.23.0
transformers==4.38.0
peft==0.5.0
torch==2.2.0
torch==2.0.1
torchvision>=0.16
ftfy==6.1.1
tensorboard==2.14.0
Jinja2==3.1.4
Jinja2==3.1.3
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -57,7 +57,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -68,7 +68,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -478,7 +478,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--debug_loss",
action="store_true",
help="debug loss for each image, if filenames are available in the dataset",
help="debug loss for each image, if filenames are awailable in the dataset",
)
if input_args is not None:
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
-3
View File
@@ -109,9 +109,6 @@ import torch
model_id = "path-to-your-trained-model"
pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch.float16).to("cuda")
repo_id_embeds = "path-to-your-learned-embeds"
pipe.load_textual_inversion(repo_id_embeds)
prompt = "A <cat-toy> backpack"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+1 -22
View File
@@ -23,25 +23,4 @@ accelerate launch textual_inversion_sdxl.py \
--output_dir="./textual_inversion_cat_sdxl"
```
Training of both text encoders is supported.
### Inference Example
Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionXLPipeline`.
Make sure to include the `placeholder_token` in your prompt.
```python
from diffusers import StableDiffusionXLPipeline
import torch
model_id = "./textual_inversion_cat_sdxl"
pipe = StableDiffusionXLPipeline.from_pretrained(model_id,torch_dtype=torch.float16).to("cuda")
prompt = "A <cat-toy> backpack"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("cat-backpack.png")
image = pipe(prompt="", prompt_2=prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("cat-backpack-prompt_2.png")
```
For now, only training of the first text encoder is supported.
@@ -81,7 +81,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = logging.getLogger(__name__)
@@ -76,7 +76,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
@@ -135,7 +135,7 @@ def log_validation(
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder_1),
text_encoder_2=accelerator.unwrap_model(text_encoder_2),
text_encoder_2=text_encoder_2,
tokenizer=tokenizer_1,
tokenizer_2=tokenizer_2,
unet=unet,
@@ -678,54 +678,36 @@ def main():
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer."
)
num_added_tokens = tokenizer_2.add_tokens(placeholder_tokens)
if num_added_tokens != args.num_vectors:
raise ValueError(
f"The 2nd tokenizer already contains the token {args.placeholder_token}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer."
)
# Convert the initializer_token, placeholder_token to ids
token_ids = tokenizer_1.encode(args.initializer_token, add_special_tokens=False)
token_ids_2 = tokenizer_2.encode(args.initializer_token, add_special_tokens=False)
# Check if initializer_token is a single token or a sequence of tokens
if len(token_ids) > 1 or len(token_ids_2) > 1:
if len(token_ids) > 1:
raise ValueError("The initializer token must be a single token.")
initializer_token_id = token_ids[0]
placeholder_token_ids = tokenizer_1.convert_tokens_to_ids(placeholder_tokens)
initializer_token_id_2 = token_ids_2[0]
placeholder_token_ids_2 = tokenizer_2.convert_tokens_to_ids(placeholder_tokens)
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder_1.resize_token_embeddings(len(tokenizer_1))
text_encoder_2.resize_token_embeddings(len(tokenizer_2))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder_1.get_input_embeddings().weight.data
token_embeds_2 = text_encoder_2.get_input_embeddings().weight.data
with torch.no_grad():
for token_id in placeholder_token_ids:
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
for token_id in placeholder_token_ids_2:
token_embeds_2[token_id] = token_embeds_2[initializer_token_id_2].clone()
# Freeze vae and unet
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder_2.requires_grad_(False)
# Freeze all parameters except for the token embeddings in text encoder
text_encoder_1.text_model.encoder.requires_grad_(False)
text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
text_encoder_2.text_model.encoder.requires_grad_(False)
text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
if args.gradient_checkpointing:
text_encoder_1.gradient_checkpointing_enable()
text_encoder_2.gradient_checkpointing_enable()
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
@@ -764,11 +746,7 @@ def main():
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(
# only optimize the embeddings
[
text_encoder_1.text_model.embeddings.token_embedding.weight,
text_encoder_2.text_model.embeddings.token_embedding.weight,
],
text_encoder_1.get_input_embeddings().parameters(), # only optimize the embeddings
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -808,10 +786,9 @@ def main():
)
text_encoder_1.train()
text_encoder_2.train()
# Prepare everything with our `accelerator`.
text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler
text_encoder_1, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder_1, optimizer, train_dataloader, lr_scheduler
)
# For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
@@ -889,13 +866,11 @@ def main():
# keep original embeddings as reference
orig_embeds_params = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone()
orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone()
for epoch in range(first_epoch, args.num_train_epochs):
text_encoder_1.train()
text_encoder_2.train()
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate([text_encoder_1, text_encoder_2]):
with accelerator.accumulate(text_encoder_1):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
latents = latents * vae.config.scaling_factor
@@ -917,7 +892,9 @@ def main():
.hidden_states[-2]
.to(dtype=weight_dtype)
)
encoder_output_2 = text_encoder_2(batch["input_ids_2"], output_hidden_states=True)
encoder_output_2 = text_encoder_2(
batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True
)
encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype)
original_size = [
(batch["original_size"][0][i].item(), batch["original_size"][1][i].item())
@@ -961,16 +938,11 @@ def main():
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.ones((len(tokenizer_1),), dtype=torch.bool)
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
index_no_updates_2 = torch.ones((len(tokenizer_2),), dtype=torch.bool)
index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
with torch.no_grad():
accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
index_no_updates_2
] = orig_embeds_params_2[index_no_updates_2]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -988,16 +960,6 @@ def main():
save_path,
safe_serialization=True,
)
weight_name = f"learned_embeds_2-steps-{global_step}.safetensors"
save_path = os.path.join(args.output_dir, weight_name)
save_progress(
text_encoder_2,
placeholder_token_ids_2,
accelerator,
args,
save_path,
safe_serialization=True,
)
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
@@ -1072,7 +1034,7 @@ def main():
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder_1),
text_encoder_2=accelerator.unwrap_model(text_encoder_2),
text_encoder_2=text_encoder_2,
vae=vae,
unet=unet,
tokenizer=tokenizer_1,
@@ -1090,16 +1052,6 @@ def main():
save_path,
safe_serialization=True,
)
weight_name = "learned_embeds_2.safetensors"
save_path = os.path.join(args.output_dir, weight_name)
save_progress(
text_encoder_2,
placeholder_token_ids_2,
accelerator,
args,
save_path,
safe_serialization=True,
)
if args.push_to_hub:
save_model_card(
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
+1 -1
View File
@@ -50,7 +50,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -50,7 +50,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
check_min_version("0.30.0.dev0")
logger = get_logger(__name__, log_level="INFO")
+2 -1
View File
@@ -135,6 +135,7 @@ _deps = [
"transformers>=4.41.2",
"urllib3<=2.0.0",
"black",
"tokenizers==0.20.0rc1"
]
# this is a lookup table with items like:
@@ -254,7 +255,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
version="0.31.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.30.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
+1 -6
View File
@@ -1,4 +1,4 @@
__version__ = "0.31.0.dev0"
__version__ = "0.30.0.dev0"
from typing import TYPE_CHECKING
@@ -88,7 +88,6 @@ else:
"ControlNetModel",
"ControlNetXSAdapter",
"DiTTransformer2DModel",
"FluxControlNetModel",
"FluxTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
@@ -255,7 +254,6 @@ else:
"CLIPImageProjection",
"CogVideoXPipeline",
"CycleDiffusionPipeline",
"FluxControlNetPipeline",
"FluxPipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
@@ -310,7 +308,6 @@ else:
"StableCascadeCombinedPipeline",
"StableCascadeDecoderPipeline",
"StableCascadePriorPipeline",
"StableDiffusion3ControlNetInpaintingPipeline",
"StableDiffusion3ControlNetPipeline",
"StableDiffusion3Img2ImgPipeline",
"StableDiffusion3InpaintPipeline",
@@ -552,7 +549,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ControlNetModel,
ControlNetXSAdapter,
DiTTransformer2DModel,
FluxControlNetModel,
FluxTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
@@ -697,7 +693,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CLIPImageProjection,
CogVideoXPipeline,
CycleDiffusionPipeline,
FluxControlNetPipeline,
FluxPipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
+1 -5
View File
@@ -222,11 +222,7 @@ class IPAdapterMixin:
# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
default_clip_size = 224
clip_image_size = (
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
)
clip_image_size = self.image_encoder.config.image_size
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
self.register_modules(feature_extractor=feature_extractor)
-2
View File
@@ -35,7 +35,6 @@ if is_torch_available():
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["controlnet_flux"] = ["FluxControlNetModel"]
_import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
_import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
@@ -88,7 +87,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VQModel,
)
from .controlnet import ControlNetModel
from .controlnet_flux import FluxControlNetModel
from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
from .controlnet_sparsectrl import SparseControlNetModel
+1 -1
View File
@@ -449,7 +449,7 @@ class BasicTransformerBlock(nn.Module):
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif self.norm_type == "ada_norm_single":
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None].to(timestep.dtype) + timestep.reshape(batch_size, 6, -1)
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
@@ -60,8 +60,6 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
"""
_always_upcast_modules = ["MaskConditionDecoder"]
@register_to_config
def __init__(
self,
@@ -70,7 +70,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
_always_upcast_modules = ["Decoder"]
@register_to_config
def __init__(
@@ -36,7 +36,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class CogVideoXSafeConv3d(nn.Conv3d):
r"""
"""
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
"""
@@ -68,12 +68,12 @@ class CogVideoXCausalConv3d(nn.Module):
r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
Args:
in_channels (`int`): Number of channels in the input tensor.
out_channels (`int`): Number of output channels produced by the convolution.
kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
stride (`int`, defaults to `1`): Stride of the convolution.
dilation (`int`, defaults to `1`): Dilation rate of the convolution.
pad_mode (`str`, defaults to `"constant"`): Padding mode.
in_channels (int): Number of channels in the input tensor.
out_channels (int): Number of output channels.
kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolutional kernel.
stride (int, optional): Stride of the convolution. Default is 1.
dilation (int, optional): Dilation rate of the convolution. Default is 1.
pad_mode (str, optional): Padding mode. Default is "constant".
"""
def __init__(
@@ -118,12 +118,19 @@ class CogVideoXCausalConv3d(nn.Module):
self.conv_cache = None
def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
dim = self.temporal_dim
kernel_size = self.time_kernel_size
if kernel_size > 1:
cached_inputs = (
[self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
)
inputs = torch.cat(cached_inputs + [inputs], dim=2)
if kernel_size == 1:
return inputs
inputs = inputs.transpose(0, dim)
if self.conv_cache is not None:
inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0)
else:
inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0)
inputs = inputs.transpose(0, dim).contiguous()
return inputs
def _clear_fake_context_parallel_cache(self):
@@ -131,17 +138,16 @@ class CogVideoXCausalConv3d(nn.Module):
self.conv_cache = None
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = self.fake_context_parallel_forward(inputs)
input_parallel = self.fake_context_parallel_forward(inputs)
self._clear_fake_context_parallel_cache()
# Note: we could move these to the cpu for a lower maximum memory usage but its only a few
# hundred megabytes and so let's not do it for now
self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
output = self.conv(inputs)
output_parallel = self.conv(input_parallel)
output = output_parallel
return output
@@ -157,8 +163,6 @@ class CogVideoXSpatialNorm3D(nn.Module):
The number of channels for input to group normalization layer, and output of the spatial norm layer.
zq_channels (`int`):
The number of channels for the quantized vector as described in the paper.
groups (`int`):
Number of groups to separate the channels into for group normalization.
"""
def __init__(
@@ -193,26 +197,17 @@ class CogVideoXResnetBlock3D(nn.Module):
A 3D ResNet block used in the CogVideoX model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
dropout (`float`, defaults to `0.0`):
Dropout rate.
temb_channels (`int`, defaults to `512`):
Number of time embedding channels.
groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
non_linearity (`str`, defaults to `"swish"`):
Activation function to use.
conv_shortcut (bool, defaults to `False`):
Whether or not to use a convolution shortcut.
spatial_norm_dim (`int`, *optional*):
The dimension to use for spatial norm if it is to be used instead of group norm.
pad_mode (str, defaults to `"first"`):
Padding mode.
in_channels (int): Number of input channels.
out_channels (Optional[int], optional):
Number of output channels. If None, defaults to `in_channels`. Default is None.
dropout (float, optional): Dropout rate. Default is 0.0.
temb_channels (int, optional): Number of time embedding channels. Default is 512.
groups (int, optional): Number of groups for group normalization. Default is 32.
eps (float, optional): Epsilon value for normalization layers. Default is 1e-6.
non_linearity (str, optional): Activation function to use. Default is "swish".
conv_shortcut (bool, optional): If True, use a convolutional shortcut. Default is False.
spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
pad_mode (str, optional): Padding mode. Default is "first".
"""
def __init__(
@@ -314,28 +309,18 @@ class CogVideoXDownBlock3D(nn.Module):
A downsampling block used in the CogVideoX model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
temb_channels (`int`, defaults to `512`):
Number of time embedding channels.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
dropout (`float`, defaults to `0.0`):
Dropout rate.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
resnet_groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
add_downsample (`bool`, defaults to `True`):
Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
compress_time (`bool`, defaults to `False`):
Whether or not to downsample across temporal dimension.
pad_mode (str, defaults to `"first"`):
Padding mode.
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
temb_channels (int): Number of time embedding channels.
dropout (float, optional): Dropout rate. Default is 0.0.
num_layers (int, optional): Number of layers in the block. Default is 1.
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
add_downsample (bool, optional): If True, add a downsampling layer at the end of the block. Default is True.
downsample_padding (int, optional): Padding for the downsampling layer. Default is 0.
compress_time (bool, optional): If True, apply temporal compression. Default is False.
pad_mode (str, optional): Padding mode. Default is "first".
"""
_supports_gradient_checkpointing = True
@@ -420,24 +405,15 @@ class CogVideoXMidBlock3D(nn.Module):
A middle block used in the CogVideoX model.
Args:
in_channels (`int`):
Number of input channels.
temb_channels (`int`, defaults to `512`):
Number of time embedding channels.
dropout (`float`, defaults to `0.0`):
Dropout rate.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
resnet_groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
spatial_norm_dim (`int`, *optional*):
The dimension to use for spatial norm if it is to be used instead of group norm.
pad_mode (str, defaults to `"first"`):
Padding mode.
in_channels (int): Number of input channels.
temb_channels (int): Number of time embedding channels.
dropout (float, optional): Dropout rate. Default is 0.0.
num_layers (int, optional): Number of layers in the block. Default is 1.
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
pad_mode (str, optional): Padding mode. Default is "first".
"""
_supports_gradient_checkpointing = True
@@ -504,30 +480,19 @@ class CogVideoXUpBlock3D(nn.Module):
An upsampling block used in the CogVideoX model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
temb_channels (`int`, defaults to `512`):
Number of time embedding channels.
dropout (`float`, defaults to `0.0`):
Dropout rate.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
resnet_groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
spatial_norm_dim (`int`, defaults to `16`):
The dimension to use for spatial norm if it is to be used instead of group norm.
add_upsample (`bool`, defaults to `True`):
Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension.
compress_time (`bool`, defaults to `False`):
Whether or not to downsample across temporal dimension.
pad_mode (str, defaults to `"first"`):
Padding mode.
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
temb_channels (int): Number of time embedding channels.
dropout (float, optional): Dropout rate. Default is 0.0.
num_layers (int, optional): Number of layers in the block. Default is 1.
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
spatial_norm_dim (int, optional): Dimension of the spatial normalization. Default is 16.
add_upsample (bool, optional): If True, add an upsampling layer at the end of the block. Default is True.
upsample_padding (int, optional): Padding for the upsampling layer. Default is 1.
compress_time (bool, optional): If True, apply temporal compression. Default is False.
pad_mode (str, optional): Padding mode. Default is "first".
"""
def __init__(
@@ -622,12 +587,14 @@ class CogVideoXEncoder3D(nn.Module):
options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels for the last block.
"""
_supports_gradient_checkpointing = True
@@ -756,12 +723,14 @@ class CogVideoXDecoder3D(nn.Module):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
norm_type (`str`, *optional*, defaults to `"group"`):
The normalization type to use. Can be either `"group"` or `"spatial"`.
"""
_supports_gradient_checkpointing = True
@@ -942,8 +911,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
norm_eps: float = 1e-6,
norm_num_groups: int = 32,
temporal_compression_ratio: float = 4,
sample_height: int = 480,
sample_width: int = 720,
sample_size: int = 256,
scaling_factor: float = 1.15258426,
shift_factor: Optional[float] = None,
latents_mean: Optional[Tuple[float]] = None,
@@ -982,105 +950,25 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.use_slicing = False
self.use_tiling = False
# Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
# recommended because the temporal parts of the VAE, here, are tricky to understand.
# If you decode X latent frames together, the number of output frames is:
# (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
#
# Example with num_latent_frames_batch_size = 2:
# - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
# => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
# => 6 * 8 = 48 frames
# - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
# => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
# ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
# => 1 * 9 + 5 * 8 = 49 frames
# It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
# setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
# number of temporal frames.
self.num_latent_frames_batch_size = 2
# We make the minimum height and width of sample for tiling half that of the generally supported
self.tile_sample_min_height = sample_height // 2
self.tile_sample_min_width = sample_width // 2
self.tile_latent_min_height = int(
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
self.tile_sample_min_size = self.config.sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
# These are experimental overlap factors that were chosen based on experimentation and seem to work best for
# 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
# and so the tiling implementation has only been tested on those specific resolutions.
self.tile_overlap_factor_height = 1 / 6
self.tile_overlap_factor_width = 1 / 5
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
module.gradient_checkpointing = value
def _clear_fake_context_parallel_cache(self):
def clear_fake_context_parallel_cache(self):
for name, module in self.named_modules():
if isinstance(module, CogVideoXCausalConv3d):
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
module._clear_fake_context_parallel_cache()
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_overlap_factor_height: Optional[float] = None,
tile_overlap_factor_width: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_overlap_factor_height (`int`, *optional*):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed leading to slow down of the decoding process.
tile_overlap_factor_width (`int`, *optional*):
The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed leading to slow down of the decoding process.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_latent_min_height = int(
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
)
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
@@ -1105,34 +993,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
frame_batch_size = self.num_latent_frames_batch_size
dec = []
for i in range(num_frames // frame_batch_size):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
z_intermediate = z[:, :, start_frame:end_frame]
if self.post_quant_conv is not None:
z_intermediate = self.post_quant_conv(z_intermediate)
z_intermediate = self.decoder(z_intermediate)
dec.append(z_intermediate)
self._clear_fake_context_parallel_cache()
dec = torch.cat(dec, dim=2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
@@ -1145,111 +1007,13 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
# Rough memory assessment:
# - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
# - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
# - Assume fp16 (2 bytes per value).
# Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
#
# Memory assessment when using tiling:
# - Assume everything as above but now HxW is 240x360 by tiling in half
# Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
batch_size, num_channels, num_frames, height, width = z.shape
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_sample_min_height - blend_extent_height
row_limit_width = self.tile_sample_min_width - blend_extent_width
frame_batch_size = self.num_latent_frames_batch_size
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
time = []
for k in range(num_frames // frame_batch_size):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = z[
:,
:,
start_frame:end_frame,
i : i + self.tile_latent_min_height,
j : j + self.tile_latent_min_width,
]
if self.post_quant_conv is not None:
tile = self.post_quant_conv(tile)
tile = self.decoder(tile)
time.append(tile)
self._clear_fake_context_parallel_cache()
row.append(torch.cat(time, dim=2))
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))
dec = torch.cat(result_rows, dim=3)
if self.post_quant_conv is not None:
z = self.post_quant_conv(z)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
@@ -192,7 +192,6 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
_always_upcast_modules = ["TemporalDecoder"]
@register_to_config
def __init__(
@@ -317,7 +317,6 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = False
_always_upcast_modules = ["OobleckEncoder", "OobleckDecoder"]
@register_to_config
def __init__(
@@ -330,7 +330,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.
"""
z = (z * self.config.scaling_factor - self.means.to(z.dtype)) / self.stds.to(z.dtype)
z = (z * self.config.scaling_factor - self.means) / self.stds
scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)
@@ -71,8 +71,6 @@ class VQModel(ModelMixin, ConfigMixin):
Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
"""
_always_upcast_modules = ["Decoder", "VectorQuantizer"]
@register_to_config
def __init__(
self,
-374
View File
@@ -1,374 +0,0 @@
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import PeftAdapterMixin
from ..models.attention_processor import AttentionProcessor
from ..models.modeling_utils import ModelMixin
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from .controlnet import BaseOutput, zero_module
from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
from .modeling_outputs import Transformer2DModelOutput
from .transformers.transformer_flux import EmbedND, FluxSingleTransformerBlock, FluxTransformerBlock
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class FluxControlNetOutput(BaseOutput):
controlnet_block_samples: Tuple[torch.Tensor]
controlnet_single_block_samples: Tuple[torch.Tensor]
class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
self.time_text_embed = text_time_guidance_cls(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for i in range(num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for i in range(num_single_layers)
]
)
# controlnet_blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(len(self.single_transformer_blocks)):
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self):
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@classmethod
def from_transformer(
cls,
transformer,
num_layers=4,
num_single_layers=10,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
load_weights_from_transformer=True,
):
config = transformer.config
config["num_layers"] = num_layers
config["num_single_layers"] = num_single_layers
config["attention_head_dim"] = attention_head_dim
config["num_attention_heads"] = num_attention_heads
controlnet = cls(**config)
if load_weights_from_transformer:
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
controlnet.single_transformer_blocks.load_state_dict(
transformer.single_transformer_blocks.state_dict(), strict=False
)
controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
return controlnet
def forward(
self,
hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
# add
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pos_embed(ids)
block_samples = ()
for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
block_samples = block_samples + (hidden_states,)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
single_block_samples = ()
for index_block, block in enumerate(self.single_transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
# controlnet block
controlnet_block_samples = ()
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
block_sample = controlnet_block(block_sample)
controlnet_block_samples = controlnet_block_samples + (block_sample,)
controlnet_single_block_samples = ()
for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
single_block_sample = controlnet_block(single_block_sample)
controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
# scaling
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
#
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
controlnet_single_block_samples = (
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (controlnet_block_samples, controlnet_single_block_samples)
return FluxControlNetOutput(
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
)
+1 -2
View File
@@ -55,7 +55,6 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
pooled_projection_dim: int = 2048,
out_channels: int = 16,
pos_embed_max_size: int = 96,
extra_conditioning_channels: int = 0,
):
super().__init__()
default_out_channels = in_channels
@@ -99,7 +98,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels + extra_conditioning_channels,
in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_type=None,
)
-74
View File
@@ -263,80 +263,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
"""
self.set_use_memory_efficient_attention_xformers(False)
def enable_layerwise_upcasting(self, upcast_dtype=None):
r"""
Enable layerwise dynamic upcasting. This allows models to be loaded into the GPU in a low memory dtype e.g.
torch.float8_e4m3fn, but perform inference using a dtype that is supported by the GPU, by upcasting the
individual modules in the model to the appropriate dtype right before the foward pass.
The module is then moved back to the low memory dtype after the foward pass.
"""
upcast_dtype = upcast_dtype or torch.float32
original_dtype = self.dtype
def upcast_dtype_hook_fn(module, *args, **kwargs):
module = module.to(upcast_dtype)
def cast_to_original_dtype_hook_fn(module, *args, **kwargs):
module = module.to(original_dtype)
def fn_recursive_upcast(module):
"""In certain cases modules will apply casting internally or reference the dtype of internal blocks.
e.g.
```
class MyModel(nn.Module):
def forward(self, x):
dtype = next(iter(self.blocks.parameters())).dtype
x = self.blocks(x) + torch.ones(x.size()).to(dtype)
```
Layerwise upcasting will not work here, since the internal blocks remain in the low memory dtype until
their `forward` method is called. We need to add the upcast hook on the entire module in order for the
operation to work.
The `_always_upcast_modules` class attribute is a list of modules within the model that we must upcast
entirely, rather than layerwise.
"""
if hasattr(self, "_always_upcast_modules") and module.__class__.__name__ in self._always_upcast_modules:
# Upcast entire module and exist recursion
module.register_forward_pre_hook(upcast_dtype_hook_fn)
module.register_forward_hook(cast_to_original_dtype_hook_fn)
return
has_children = list(module.children())
if not has_children:
module.register_forward_pre_hook(upcast_dtype_hook_fn)
module.register_forward_hook(cast_to_original_dtype_hook_fn)
for child in module.children():
fn_recursive_upcast(child)
for module in self.children():
fn_recursive_upcast(module)
def disable_layerwise_upcasting(self):
def fn_recursive_upcast(module):
if hasattr(self, "_always_upcast_modules") and module.__class__.__name__ in self._always_upcast_modules:
module._forward_pre_hooks = OrderedDict()
module._forward_hooks = OrderedDict()
return
has_children = list(module.children())
if not has_children:
module._forward_pre_hooks = OrderedDict()
module._forward_hooks = OrderedDict()
for child in module.children():
fn_recursive_upcast(child)
for module in self.children():
fn_recursive_upcast(module)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
@@ -68,21 +68,6 @@ class AuraFlowPatchEmbed(nn.Module):
self.height, self.width = height // patch_size, width // patch_size
self.base_size = height // patch_size
def pe_selection_index_based_on_dim(self, h, w):
# select subset of positional embedding based on H, W, where H, W is size of latent
# PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
# because original input are in flattened format, we have to flatten this 2d grid as well.
h_p, w_p = h // self.patch_size, w // self.patch_size
original_pe_indexes = torch.arange(self.pos_embed.shape[1])
h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
original_pe_indexes = original_pe_indexes.view(h_max, w_max)
starth = h_max // 2 - h_p // 2
endh = starth + h_p
startw = w_max // 2 - w_p // 2
endw = startw + w_p
original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
return original_pe_indexes.flatten()
def forward(self, latent):
batch_size, num_channels, height, width = latent.size()
latent = latent.view(
@@ -95,8 +80,7 @@ class AuraFlowPatchEmbed(nn.Module):
)
latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
latent = self.proj(latent)
pe_index = self.pe_selection_index_based_on_dim(height, width)
return latent + self.pos_embed[:, pe_index]
return latent + self.pos_embed
# Taken from the original Aura flow inference code.
@@ -274,9 +258,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
"""
_no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
_supports_gradient_checkpointing = True
_always_upcast_modules = ["AuraFlowPatchEmbed"]
@register_to_config
def __init__(
@@ -458,15 +440,11 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
# Apply patch embedding, timestep embedding, and project the caption embeddings.
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
temb = self.time_step_embed(timestep).to(dtype=hidden_states.dtype)
temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype)
temb = self.time_step_proj(temb)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
encoder_hidden_states = torch.cat(
[
self.register_tokens.to(encoder_hidden_states.dtype).repeat(encoder_hidden_states.size(0), 1, 1),
encoder_hidden_states,
],
dim=1,
[self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1
)
# MMDiT blocks.
@@ -37,20 +37,13 @@ class CogVideoXBlock(nn.Module):
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
Parameters:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
time_embed_dim (`int`):
The number of channels in timestep embedding.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to be used in feed-forward.
attention_bias (`bool`, defaults to `False`):
Whether or not to use bias in attention projection layers.
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
qk_norm (`bool`, defaults to `True`):
Whether or not to use normalization after query and key projections in Attention.
norm_elementwise_affine (`bool`, defaults to `True`):
@@ -154,53 +147,36 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
Parameters:
num_attention_heads (`int`, defaults to `30`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
in_channels (`int`, defaults to `16`):
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `16`):
out_channels (`int`, *optional*):
The number of channels in the output.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
num_layers (`int`, defaults to `30`):
The number of layers of Transformer blocks to use.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
attention_bias (`bool`, defaults to `True`):
Whether or not to use bias in the attention projection layers.
sample_width (`int`, defaults to `90`):
The width of the input latents.
sample_height (`int`, defaults to `60`):
The height of the input latents.
sample_frames (`int`, defaults to `49`):
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
patch_size (`int`, defaults to `2`):
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
patch_size (`int`, *optional*):
The size of the patches to use in the patch embedding layer.
temporal_compression_ratio (`int`, defaults to `4`):
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
max_text_seq_length (`int`, defaults to `226`):
The maximum sequence length of the input text embeddings.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
timestep_activation_fn (`str`, defaults to `"silu"`):
Activation function to use when generating the timestep embeddings.
norm_elementwise_affine (`bool`, defaults to `True`):
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
num_embeds_ada_norm ( `int`, *optional*):
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
added to the hidden states. During inference, you can denoise for up to but not more steps than
`num_embeds_ada_norm`.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether or not to use elementwise affine in normalization layers.
norm_eps (`float`, defaults to `1e-5`):
The epsilon value to use in normalization layers.
spatial_interpolation_scale (`float`, defaults to `1.875`):
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
temporal_interpolation_scale (`float`, defaults to `1.0`):
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
caption_channels (`int`, *optional*):
The number of channels in the caption embeddings.
video_length (`int`, *optional*):
The number of frames in the video-like data.
"""
_supports_gradient_checkpointing = True
@@ -210,7 +186,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
self,
num_attention_heads: int = 30,
attention_head_dim: int = 64,
in_channels: int = 16,
in_channels: Optional[int] = 16,
out_channels: Optional[int] = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
@@ -328,7 +304,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
hidden_states = hidden_states[:, self.config.max_text_seq_length :]
# 4. Transformer blocks
# 5. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
@@ -355,11 +331,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states = self.norm_final(hidden_states)
# 5. Final block
# 6. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 6. Unpatchify
# 7. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
@@ -65,7 +65,6 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
_always_upcast_modules = ["PatchEmbed"]
@register_to_config
def __init__(
@@ -244,8 +244,6 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
"""
_always_upcast_modules = ["HunyuanDiTAttentionPool"]
@register_to_config
def __init__(
self,
@@ -486,9 +484,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
encoder_hidden_states = torch.where(
text_embedding_mask, encoder_hidden_states, self.text_embedding_padding.to(encoder_hidden_states.dtype)
)
encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
skips = []
for layer, block in enumerate(self.blocks):
@@ -64,7 +64,6 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
video_length (`int`, *optional*):
The number of frames in the video-like data.
"""
_always_upcast_modules = ["PatchEmbed"]
@register_to_config
def __init__(
@@ -302,9 +301,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
shift, scale = (self.scale_shift_table[None].to(embedded_timestep.dtype) + embedded_timestep[:, None]).chunk(
2, dim=1
)
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
@@ -19,7 +19,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ..attention import BasicTransformerBlock
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -79,7 +79,6 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
_always_upcast_modules = ["PatchEmbed"]
@register_to_config
def __init__(
@@ -248,14 +247,6 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
"""
self.set_attn_processor(AttnProcessor())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
@@ -423,8 +414,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
# 3. Output
shift, scale = (
self.scale_shift_table[None].to(embedded_timestep.dtype)
+ embedded_timestep[:, None].to(self.scale_shift_table.device)
self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
@@ -289,7 +289,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might be fp16, so we need to cast here.
timesteps_projected = timesteps_projected.to(dtype=hidden_states.dtype)
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
time_embeddings = self.time_embedding(timesteps_projected)
if self.embedding_proj_norm is not None:
@@ -1,4 +1,4 @@
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
# Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -251,7 +250,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
@register_to_config
def __init__(
@@ -323,8 +321,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
@@ -381,7 +377,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pos_embed(ids)
@@ -415,12 +410,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
image_rotary_emb=image_rotary_emb,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
@@ -451,15 +440,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
image_rotary_emb=image_rotary_emb,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
@@ -54,7 +54,6 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
"""
_supports_gradient_checkpointing = True
_always_upcast_modules = ["PatchEmbed"]
@register_to_config
def __init__(
+1 -1
View File
@@ -283,7 +283,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.class_embedding is not None:
@@ -641,7 +641,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+1 -1
View File
@@ -590,7 +590,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
t_emb = t_emb.to(dtype=self.dtype)
t_emb = self.time_embedding(t_emb, timestep_cond)
# 2. FPS
@@ -2152,7 +2152,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
+5 -4
View File
@@ -124,7 +124,7 @@ else:
"AnimateDiffSparseControlNetPipeline",
"AnimateDiffVideoToVideoPipeline",
]
_import_structure["flux"] = ["FluxPipeline", "FluxControlNetPipeline"]
_import_structure["flux"] = ["FluxPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
@@ -173,7 +173,6 @@ else:
_import_structure["controlnet_sd3"].extend(
[
"StableDiffusion3ControlNetPipeline",
"StableDiffusion3ControlNetInpaintingPipeline",
]
)
_import_structure["deepfloyd_if"] = [
@@ -466,7 +465,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .controlnet_hunyuandit import (
HunyuanDiTControlNetPipeline,
)
from .controlnet_sd3 import StableDiffusion3ControlNetInpaintingPipeline, StableDiffusion3ControlNetPipeline
from .controlnet_sd3 import (
StableDiffusion3ControlNetPipeline,
)
from .controlnet_xs import (
StableDiffusionControlNetXSPipeline,
StableDiffusionXLControlNetXSPipeline,
@@ -493,7 +494,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VersatileDiffusionTextToImagePipeline,
VQDiffusionPipeline,
)
from .flux import FluxControlNetPipeline, FluxPipeline
from .flux import FluxPipeline
from .hunyuandit import HunyuanDiTPipeline
from .i2vgen_xl import I2VGenXLPipeline
from .kandinsky import (
-2
View File
@@ -49,7 +49,6 @@ from .kandinsky2_2 import (
)
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .lumina import LuminaText2ImgPipeline
from .pag import (
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
@@ -107,7 +106,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
("auraflow", AuraFlowPipeline),
("flux", FluxPipeline),
("lumina", LuminaText2ImgPipeline),
]
)
@@ -332,11 +332,20 @@ class CogVideoXPipeline(DiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma
return latents
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
def decode_latents(self, latents: torch.Tensor, num_seconds: int):
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae.config.scaling_factor * latents
frames = self.vae.decode(latents).sample
frames = []
for i in range(num_seconds):
start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
current_frames = self.vae.decode(latents[:, :, start_frame:end_frame]).sample
frames.append(current_frames)
self.vae.clear_fake_context_parallel_cache()
frames = torch.cat(frames, dim=2)
return frames
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -429,7 +438,8 @@ class CogVideoXPipeline(DiffusionPipeline):
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 480,
width: int = 720,
num_frames: int = 49,
num_frames: int = 48,
fps: int = 8,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
@@ -524,10 +534,9 @@ class CogVideoXPipeline(DiffusionPipeline):
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
if num_frames > 49:
raise ValueError(
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
)
assert (
num_frames <= 48 and num_frames % fps == 0 and fps == 8
), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX."
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -584,6 +593,7 @@ class CogVideoXPipeline(DiffusionPipeline):
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
num_frames += 1
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
latent_channels,
@@ -663,7 +673,7 @@ class CogVideoXPipeline(DiffusionPipeline):
progress_bar.update()
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.decode_latents(latents, num_frames // fps)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
@@ -23,9 +23,6 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_stable_diffusion_3_controlnet"] = ["StableDiffusion3ControlNetPipeline"]
_import_structure["pipeline_stable_diffusion_3_controlnet_inpainting"] = [
"StableDiffusion3ControlNetInpaintingPipeline"
]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -36,7 +33,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_stable_diffusion_3_controlnet import StableDiffusion3ControlNetPipeline
from .pipeline_stable_diffusion_3_controlnet_inpainting import StableDiffusion3ControlNetInpaintingPipeline
try:
if not (is_transformers_available() and is_flax_available()):
-2
View File
@@ -23,7 +23,6 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_flux"] = ["FluxPipeline"]
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
@@ -32,7 +31,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_flux import FluxPipeline
from .pipeline_flux_controlnet import FluxControlNetPipeline
else:
import sys
@@ -677,13 +677,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -693,6 +686,13 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.tensor([guidance_scale], device=device)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
noise_pred = self.transformer(
hidden_states=latents,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
@@ -1,861 +0,0 @@
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import (
CLIPTextModel,
CLIPTokenizer,
T5EncoderModel,
T5TokenizerFast,
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.controlnet_flux import FluxControlNetModel
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import FluxPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers.utils import load_image
>>> from diffusers import FluxControlNetPipeline
>>> from diffusers import FluxControlNetModel
>>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny-alpha"
>>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
>>> pipe = FluxControlNetPipeline.from_pretrained(
... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
>>> prompt = "A girl in city, 25 years old, cool, futuristic"
>>> image = pipe(
... prompt,
... control_image=control_image,
... controlnet_conditioning_scale=0.6,
... num_inference_steps=28,
... guidance_scale=3.5,
... ).images[0]
>>> image.save("flux.png")
```
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.16,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
r"""
The Flux pipeline for text-to-image generation.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
Args:
transformer ([`FluxTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
text_encoder_2 ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer_2 (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
controlnet: FluxControlNetModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
controlnet=controlnet,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
self.default_sample_size = 64
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
def _get_clip_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
):
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer_max_length,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in all text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
# We only use the pooled prompt output from the CLIPTextModel
pooled_prompt_embeds = self._get_clip_prompt_embeds(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
)
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt_2,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
return prompt_embeds, pooled_prompt_embeds, text_ids
def check_inputs(
self,
prompt,
prompt_2,
height,
width,
prompt_embeds=None,
pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt_2 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
return latents
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = 2 * (int(height) // self.vae_scale_factor)
width = 2 * (int(width) // self.vae_scale_factor)
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents, latent_image_ids
# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
def prepare_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
guess_mode=False,
):
if isinstance(image, torch.Tensor):
pass
else:
image = self.image_processor.preprocess(image, height=height, width=width)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance and not guess_mode:
image = torch.cat([image] * 2)
return image
@property
def guidance_scale(self):
return self._guidance_scale
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
control_image: PipelineImageInput = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
will be used instead
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
dtype = self.transformer.dtype
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 3. Prepare control image
num_channels_latents = self.transformer.config.in_channels // 4
if isinstance(self.controlnet, FluxControlNetModel):
control_image = self.prepare_image(
image=control_image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
)
height, width = control_image.shape[-2:]
# vae encode
control_image = self.vae.encode(control_image).latent_dist.sample()
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# pack
height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
control_image,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.tensor([guidance_scale], device=device)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
# controlnet
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
hidden_states=latents,
controlnet_cond=control_image,
conditioning_scale=controlnet_conditioning_scale,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)
noise_pred = self.transformer(
hidden_states=latents,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
@@ -56,7 +56,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers.utils import export_to_gif
>>> # You can replace the checkpoint id with "maxin-cn/Latte-1" too.
>>> pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16)
>>> pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16).to("cuda")
>>> # Enable memory optimizations.
>>> pipe.enable_model_cpu_offload()
@@ -54,7 +54,7 @@ EXAMPLE_DOC_STRING = """
>>> pipe = LuminaText2ImgPipeline.from_pretrained(
... "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
... )
... ).cuda()
>>> # Enable memory optimizations.
>>> pipe.enable_model_cpu_offload()
@@ -89,44 +89,49 @@ for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
def is_safetensors_compatible(filenames, passed_components=None) -> bool:
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
"""
Checking for safetensors compatibility:
- The model is safetensors compatible only if there is a safetensors file for each model component present in
filenames.
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
files to know which safetensors files are needed.
- The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
Converting default pytorch serialized filenames to safetensors serialized filenames:
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
extension is replaced with ".safetensors"
"""
pt_filenames = []
sf_filenames = set()
passed_components = passed_components or []
# extract all components of the pipeline and their associated files
components = {}
for filename in filenames:
if not len(filename.split("/")) == 2:
_, extension = os.path.splitext(filename)
if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
continue
component, component_filename = filename.split("/")
if component in passed_components:
continue
if extension == ".bin":
pt_filenames.append(os.path.normpath(filename))
elif extension == ".safetensors":
sf_filenames.add(os.path.normpath(filename))
components.setdefault(component, [])
components[component].append(component_filename)
for filename in pt_filenames:
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
path, filename = os.path.split(filename)
filename, extension = os.path.splitext(filename)
# iterate over all files of a component
# check if safetensor files exist for that component
# if variant is provided check if the variant of the safetensors exists
for component, component_filenames in components.items():
matches = []
for component_filename in component_filenames:
filename, extension = os.path.splitext(component_filename)
if filename.startswith("pytorch_model"):
filename = filename.replace("pytorch_model", "model")
else:
filename = filename
match_exists = extension == ".safetensors"
matches.append(match_exists)
if not any(matches):
expected_sf_filename = os.path.normpath(os.path.join(path, filename))
expected_sf_filename = f"{expected_sf_filename}.safetensors"
if expected_sf_filename not in sf_filenames:
logger.warning(f"{expected_sf_filename} not found")
return False
return True
+6 -2
View File
@@ -1416,14 +1416,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if (
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(model_filenames, passed_components=passed_components)
and not is_safetensors_compatible(
model_filenames, variant=variant, passed_components=passed_components
)
):
raise EnvironmentError(
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif use_safetensors and is_safetensors_compatible(model_filenames, passed_components=passed_components):
elif use_safetensors and is_safetensors_compatible(
model_filenames, variant=variant, passed_components=passed_components
):
ignore_patterns = ["*.bin", "*.msgpack"]
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
@@ -602,9 +602,9 @@ class StableDiffusionKDiffusionPipeline(
sigma_min: float = self.k_diffusion_model.sigmas[0].item()
sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
sigmas = sigmas.to(device)
else:
sigmas = self.scheduler.sigmas
sigmas = sigmas.to(device)
sigmas = sigmas.to(prompt_embeds.dtype)
# 6. Prepare latent variables
-15
View File
@@ -182,21 +182,6 @@ class DiTTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class FluxControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class FluxTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -287,21 +287,6 @@ class CycleDiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class FluxControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class FluxPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
+3 -50
View File
@@ -9,7 +9,7 @@ import numpy as np
import PIL.Image
import PIL.ImageOps
from .import_utils import BACKENDS_MAPPING, is_imageio_available, is_opencv_available
from .import_utils import BACKENDS_MAPPING, is_opencv_available
from .logging import get_logger
@@ -112,9 +112,9 @@ def export_to_obj(mesh, output_obj_path: str = None):
f.writelines("\n".join(combined_data))
def _legacy_export_to_video(
def export_to_video(
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10
):
) -> str:
if is_opencv_available():
import cv2
else:
@@ -134,51 +134,4 @@ def _legacy_export_to_video(
for i in range(len(video_frames)):
img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
video_writer.write(img)
return output_video_path
def export_to_video(
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10
) -> str:
# TODO: Dhruv. Remove by Diffusers release 0.33.0
# Added to prevent breaking existing code
if not is_imageio_available():
logger.warning(
(
"It is recommended to use `export_to_video` with `imageio` and `imageio-ffmpeg` as a backend. \n"
"These libraries are not present in your environment. Attempting to use legacy OpenCV backend to export video. \n"
"Support for the OpenCV backend will be deprecated in a future Diffusers version"
)
)
return _legacy_export_to_video(video_frames, output_video_path, fps)
if is_imageio_available():
import imageio
else:
raise ImportError(BACKENDS_MAPPING["imageio"][1].format("export_to_video"))
try:
imageio.plugins.ffmpeg.get_exe()
except AttributeError:
raise AttributeError(
(
"Found an existing imageio backend in your environment. Attempting to export video with imageio. \n"
"Unable to find a compatible ffmpeg installation in your environment to use with imageio. Please install via `pip install imageio-ffmpeg"
)
)
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
if isinstance(video_frames[0], np.ndarray):
video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]
elif isinstance(video_frames[0], PIL.Image.Image):
video_frames = [np.array(frame) for frame in video_frames]
with imageio.get_writer(output_video_path, fps=fps) as writer:
for frame in video_frames:
writer.append_data(frame)
return output_video_path
-19
View File
@@ -330,15 +330,6 @@ except importlib_metadata.PackageNotFoundError:
_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
_imageio_available = importlib.util.find_spec("imageio") is not None
if _imageio_available:
try:
_imageio_version = importlib_metadata.version("imageio")
logger.debug(f"Successfully imported imageio version {_imageio_version}")
except importlib_metadata.PackageNotFoundError:
_imageio_available = False
def is_torch_available():
return _torch_available
@@ -456,10 +447,6 @@ def is_sentencepiece_available():
return _sentencepiece_available
def is_imageio_available():
return _imageio_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -588,11 +575,6 @@ BITSANDBYTES_IMPORT_ERROR = """
{0} requires the bitsandbytes library but it was not found in your environment. You can install it with pip: `pip install bitsandbytes`
"""
# docstyle-ignore
IMAGEIO_IMPORT_ERROR = """
{0} requires the imageio library and ffmpeg but it was not found in your environment. You can install it with pip: `pip install imageio imageio-ffmpeg`
"""
BACKENDS_MAPPING = OrderedDict(
[
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
@@ -617,7 +599,6 @@ BACKENDS_MAPPING = OrderedDict(
("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)),
("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
]
)
+14 -29
View File
@@ -1,13 +1,12 @@
import os
import tempfile
from typing import Callable, List, Optional, Union
from urllib.parse import unquote, urlparse
import PIL.Image
import PIL.ImageOps
import requests
from .import_utils import BACKENDS_MAPPING, is_imageio_available
from .import_utils import BACKENDS_MAPPING, is_opencv_available
def load_image(
@@ -81,22 +80,11 @@ def load_video(
)
if is_url:
response = requests.get(video, stream=True)
if response.status_code != 200:
raise ValueError(f"Failed to download video. Status code: {response.status_code}")
parsed_url = urlparse(video)
file_name = os.path.basename(unquote(parsed_url.path))
suffix = os.path.splitext(file_name)[1] or ".mp4"
video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name
video_data = requests.get(video, stream=True).raw
video_path = tempfile.NamedTemporaryFile(suffix=os.path.splitext(video)[1], delete=False).name
was_tempfile_created = True
video_data = response.iter_content(chunk_size=8192)
with open(video_path, "wb") as f:
for chunk in video_data:
f.write(chunk)
f.write(video_data.read())
video = video_path
@@ -111,22 +99,19 @@ def load_video(
pass
else:
if is_imageio_available():
import imageio
if is_opencv_available():
import cv2
else:
raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video"))
raise ImportError(BACKENDS_MAPPING["opencv"][1].format("load_video"))
try:
imageio.plugins.ffmpeg.get_exe()
except AttributeError:
raise AttributeError(
"`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg"
)
video_capture = cv2.VideoCapture(video)
success, frame = video_capture.read()
while success:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_images.append(PIL.Image.fromarray(frame))
success, frame = video_capture.read()
with imageio.get_reader(video) as reader:
# Read all frames
for frame in reader:
pil_images.append(PIL.Image.fromarray(frame))
video_capture.release()
if was_tempfile_created:
os.remove(video_path)
+3 -2
View File
@@ -32,7 +32,7 @@ from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_cls = FlowMatchEulerDiscreteScheduler()
scheduler_kwargs = {}
uses_flow_matching = True
transformer_kwargs = {
@@ -80,7 +80,8 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
Related PR: https://github.com/huggingface/diffusers/pull/8584
"""
components = self.get_dummy_components()
pipe = self.pipeline_class(**components[0])
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

Some files were not shown because too many files have changed in this diff Show More