Compare commits
97 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 78c125f76f | |||
| 9e3b0fe107 | |||
| 878f609aa5 | |||
| 32da2e7673 | |||
| 9a0b906518 | |||
| b3428ad5f5 | |||
| 70a54a8230 | |||
| 6f4e60b58b | |||
| e4d65ccdd7 | |||
| 22dcceb858 | |||
| 9c6b8894ff | |||
| cf7369d418 | |||
| 123ecef2b9 | |||
| 511c9ef560 | |||
| fb6130fe90 | |||
| 2d9602cc96 | |||
| 1b1b737acb | |||
| 311845fc77 | |||
| 01c2dff338 | |||
| 03580c07b9 | |||
| fd11c0fbee | |||
| 92c8c00756 | |||
| 5781e017dd | |||
| 90aa8be534 | |||
| 2f1b7870e2 | |||
| 7360ea1d03 | |||
| ba1855c07e | |||
| 1b1b26b65c | |||
| 312f7dc4fd | |||
| ba4223ac3b | |||
| 6988cc3a86 | |||
| fa7fa9cced | |||
| 61c6da076a | |||
| c7ee165c4f | |||
| d99528be94 | |||
| fd0831c52c | |||
| 477e12b235 | |||
| b42b079213 | |||
| 21509aa7f5 | |||
| 65f6211f1f | |||
| 3def90523d | |||
| 551c884acd | |||
| ec53a30a0e | |||
| 71e7c82ae8 | |||
| c33dd0213b | |||
| e12458e16c | |||
| 77558f31bf | |||
| 41da084fbe | |||
| 4c2e8870e6 | |||
| fe6f5d6419 | |||
| d0b8db2b11 | |||
| 351d1f009e | |||
| a31db5f952 | |||
| 03ee7cd109 | |||
| 712ddbeac6 | |||
| 03c28eef5b | |||
| e05f83479c | |||
| bb4740ce29 | |||
| 2956866ef4 | |||
| 4498cfc98c | |||
| a449ceb3ef | |||
| 45f7127ade | |||
| 73469f9562 | |||
| d45d199b99 | |||
| e67cc5ae47 | |||
| 470815cefa | |||
| 5f183bfe27 | |||
| c43a8f5b2b | |||
| 9f9d0cbb83 | |||
| 2be7469821 | |||
| 3ae9413966 | |||
| ec9508c83b | |||
| 6bcafcbaa6 | |||
| b3052807e5 | |||
| 73b041e7a9 | |||
| 1c661ce3d4 | |||
| 8fe54bcd26 | |||
| ee40f0e1ca | |||
| 0980f4dcd2 | |||
| 71bcb1e1c5 | |||
| dfeb32975d | |||
| d83c1f8447 | |||
| 21a0fc1b0d | |||
| 16967589d8 | |||
| d963b1aaa4 | |||
| e982881716 | |||
| cb5348a0c2 | |||
| aff72ec5dc | |||
| dc7e6e814f | |||
| a3d827fb8d | |||
| 84ff56eb90 | |||
| 45cb1f92d3 | |||
| 59e6669f6d | |||
| bb917755ee | |||
| bd6efd5fe4 | |||
| c341786f3e | |||
| c8e5491be0 |
@@ -202,7 +202,6 @@ Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz9
|
||||
|
||||
- https://github.com/microsoft/TaskMatrix
|
||||
- https://github.com/invoke-ai/InvokeAI
|
||||
- https://github.com/InstantID/InstantID
|
||||
- https://github.com/apple/ml-stable-diffusion
|
||||
- https://github.com/Sanster/lama-cleaner
|
||||
- https://github.com/IDEA-Research/Grounded-Segment-Anything
|
||||
|
||||
+56
-76
@@ -190,10 +190,6 @@
|
||||
- local: conceptual/evaluation
|
||||
title: Evaluating Diffusion Models
|
||||
title: Conceptual Guides
|
||||
- sections:
|
||||
- local: community_projects
|
||||
title: Projects built with Diffusers
|
||||
title: Community Projects
|
||||
- sections:
|
||||
- isExpanded: false
|
||||
sections:
|
||||
@@ -223,76 +219,62 @@
|
||||
sections:
|
||||
- local: api/models/overview
|
||||
title: Overview
|
||||
- sections:
|
||||
- local: api/models/controlnet
|
||||
title: ControlNetModel
|
||||
- local: api/models/controlnet_hunyuandit
|
||||
title: HunyuanDiT2DControlNetModel
|
||||
- local: api/models/controlnet_sd3
|
||||
title: SD3ControlNetModel
|
||||
- local: api/models/controlnet_sparsectrl
|
||||
title: SparseControlNetModel
|
||||
title: ControlNets
|
||||
- sections:
|
||||
- local: api/models/aura_flow_transformer2d
|
||||
title: AuraFlowTransformer2DModel
|
||||
- local: api/models/cogvideox_transformer3d
|
||||
title: CogVideoXTransformer3DModel
|
||||
- local: api/models/dit_transformer2d
|
||||
title: DiTTransformer2DModel
|
||||
- local: api/models/flux_transformer
|
||||
title: FluxTransformer2DModel
|
||||
- local: api/models/hunyuan_transformer2d
|
||||
title: HunyuanDiT2DModel
|
||||
- local: api/models/latte_transformer3d
|
||||
title: LatteTransformer3DModel
|
||||
- local: api/models/lumina_nextdit2d
|
||||
title: LuminaNextDiT2DModel
|
||||
- local: api/models/pixart_transformer2d
|
||||
title: PixArtTransformer2DModel
|
||||
- local: api/models/prior_transformer
|
||||
title: PriorTransformer
|
||||
- local: api/models/sd3_transformer2d
|
||||
title: SD3Transformer2DModel
|
||||
- local: api/models/stable_audio_transformer
|
||||
title: StableAudioDiTModel
|
||||
- local: api/models/transformer2d
|
||||
title: Transformer2DModel
|
||||
- local: api/models/transformer_temporal
|
||||
title: TransformerTemporalModel
|
||||
title: Transformers
|
||||
- sections:
|
||||
- local: api/models/stable_cascade_unet
|
||||
title: StableCascadeUNet
|
||||
- local: api/models/unet
|
||||
title: UNet1DModel
|
||||
- local: api/models/unet2d
|
||||
title: UNet2DModel
|
||||
- local: api/models/unet2d-cond
|
||||
title: UNet2DConditionModel
|
||||
- local: api/models/unet3d-cond
|
||||
title: UNet3DConditionModel
|
||||
- local: api/models/unet-motion
|
||||
title: UNetMotionModel
|
||||
- local: api/models/uvit2d
|
||||
title: UViT2DModel
|
||||
title: UNets
|
||||
- sections:
|
||||
- local: api/models/autoencoderkl
|
||||
title: AutoencoderKL
|
||||
- local: api/models/autoencoderkl_cogvideox
|
||||
title: AutoencoderKLCogVideoX
|
||||
- local: api/models/asymmetricautoencoderkl
|
||||
title: AsymmetricAutoencoderKL
|
||||
- local: api/models/consistency_decoder_vae
|
||||
title: ConsistencyDecoderVAE
|
||||
- local: api/models/autoencoder_oobleck
|
||||
title: Oobleck AutoEncoder
|
||||
- local: api/models/autoencoder_tiny
|
||||
title: Tiny AutoEncoder
|
||||
- local: api/models/vq
|
||||
title: VQModel
|
||||
title: VAEs
|
||||
- local: api/models/unet
|
||||
title: UNet1DModel
|
||||
- local: api/models/unet2d
|
||||
title: UNet2DModel
|
||||
- local: api/models/unet2d-cond
|
||||
title: UNet2DConditionModel
|
||||
- local: api/models/unet3d-cond
|
||||
title: UNet3DConditionModel
|
||||
- local: api/models/unet-motion
|
||||
title: UNetMotionModel
|
||||
- local: api/models/uvit2d
|
||||
title: UViT2DModel
|
||||
- local: api/models/vq
|
||||
title: VQModel
|
||||
- local: api/models/autoencoderkl
|
||||
title: AutoencoderKL
|
||||
- local: api/models/asymmetricautoencoderkl
|
||||
title: AsymmetricAutoencoderKL
|
||||
- local: api/models/autoencoder_tiny
|
||||
title: Tiny AutoEncoder
|
||||
- local: api/models/autoencoder_oobleck
|
||||
title: Oobleck AutoEncoder
|
||||
- local: api/models/consistency_decoder_vae
|
||||
title: ConsistencyDecoderVAE
|
||||
- local: api/models/transformer2d
|
||||
title: Transformer2DModel
|
||||
- local: api/models/pixart_transformer2d
|
||||
title: PixArtTransformer2DModel
|
||||
- local: api/models/dit_transformer2d
|
||||
title: DiTTransformer2DModel
|
||||
- local: api/models/hunyuan_transformer2d
|
||||
title: HunyuanDiT2DModel
|
||||
- local: api/models/aura_flow_transformer2d
|
||||
title: AuraFlowTransformer2DModel
|
||||
- local: api/models/flux_transformer
|
||||
title: FluxTransformer2DModel
|
||||
- local: api/models/latte_transformer3d
|
||||
title: LatteTransformer3DModel
|
||||
- local: api/models/lumina_nextdit2d
|
||||
title: LuminaNextDiT2DModel
|
||||
- local: api/models/transformer_temporal
|
||||
title: TransformerTemporalModel
|
||||
- local: api/models/sd3_transformer2d
|
||||
title: SD3Transformer2DModel
|
||||
- local: api/models/stable_audio_transformer
|
||||
title: StableAudioDiTModel
|
||||
- local: api/models/prior_transformer
|
||||
title: PriorTransformer
|
||||
- local: api/models/controlnet
|
||||
title: ControlNetModel
|
||||
- local: api/models/controlnet_hunyuandit
|
||||
title: HunyuanDiT2DControlNetModel
|
||||
- local: api/models/controlnet_sd3
|
||||
title: SD3ControlNetModel
|
||||
- local: api/models/controlnet_sparsectrl
|
||||
title: SparseControlNetModel
|
||||
title: Models
|
||||
- isExpanded: false
|
||||
sections:
|
||||
@@ -314,8 +296,6 @@
|
||||
title: AutoPipeline
|
||||
- local: api/pipelines/blip_diffusion
|
||||
title: BLIP-Diffusion
|
||||
- local: api/pipelines/cogvideox
|
||||
title: CogVideoX
|
||||
- local: api/pipelines/consistency_models
|
||||
title: Consistency Models
|
||||
- local: api/pipelines/controlnet
|
||||
|
||||
@@ -53,7 +53,6 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
|
||||
- [`AutoencoderKLCogVideoX`]
|
||||
- [`ControlNetModel`]
|
||||
- [`SD3Transformer2DModel`]
|
||||
- [`FluxTransformer2DModel`]
|
||||
|
||||
## FromSingleFileMixin
|
||||
|
||||
|
||||
@@ -11,14 +11,18 @@ specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLCogVideoX
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss used in [CogVideoX](https://github.com/THUDM/CogVideo) was introduced in [CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) by Tsinghua University & ZhipuAI.
|
||||
The 3D variational autoencoder (VAE) model with KL loss using CogVideoX.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
## Loading from the original format
|
||||
|
||||
```python
|
||||
By default, the [`AutoencoderKLCogVideoX`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded from the original format using [`FromOriginalModelMixin.from_single_file`] as follows:
|
||||
|
||||
```py
|
||||
from diffusers import AutoencoderKLCogVideoX
|
||||
|
||||
vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16).to("cuda")
|
||||
url = "THUDM/CogVideoX-2b" # can also be a local file
|
||||
model = AutoencoderKLCogVideoX.from_single_file(url)
|
||||
|
||||
```
|
||||
|
||||
## AutoencoderKLCogVideoX
|
||||
@@ -28,10 +32,38 @@ vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-2b", subfolder="va
|
||||
- encode
|
||||
- all
|
||||
|
||||
## AutoencoderKLOutput
|
||||
## CogVideoXSafeConv3d
|
||||
|
||||
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
|
||||
[[autodoc]] CogVideoXSafeConv3d
|
||||
|
||||
## DecoderOutput
|
||||
## CogVideoXCausalConv3d
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
[[autodoc]] CogVideoXCausalConv3d
|
||||
|
||||
## CogVideoXSpatialNorm3D
|
||||
|
||||
[[autodoc]] CogVideoXSpatialNorm3D
|
||||
|
||||
## CogVideoXResnetBlock3D
|
||||
|
||||
[[autodoc]] CogVideoXResnetBlock3D
|
||||
|
||||
## CogVideoXDownBlock3D
|
||||
|
||||
[[autodoc]] CogVideoXDownBlock3D
|
||||
|
||||
## CogVideoXMidBlock3D
|
||||
|
||||
[[autodoc]] CogVideoXMidBlock3D
|
||||
|
||||
## CogVideoXUpBlock3D
|
||||
|
||||
[[autodoc]] CogVideoXUpBlock3D
|
||||
|
||||
## CogVideoXEncoder3D
|
||||
|
||||
[[autodoc]] CogVideoXEncoder3D
|
||||
|
||||
## CogVideoXDecoder3D
|
||||
|
||||
[[autodoc]] CogVideoXDecoder3D
|
||||
@@ -9,22 +9,10 @@ Unless required by applicable law or agreed to in writing, software distributed
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# CogVideoXTransformer3DModel
|
||||
## CogVideoXTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D data from [CogVideoX](https://github.com/THUDM/CogVideo) was introduced in [CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) by Tsinghua University & ZhipuAI.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import CogVideoXTransformer3DModel
|
||||
|
||||
vae = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-2b", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
|
||||
```
|
||||
A Diffusion Transformer model for 3D data from [CogVideoX](https://github.com/THUDM/CogVideoX).
|
||||
|
||||
## CogVideoXTransformer3DModel
|
||||
|
||||
[[autodoc]] CogVideoXTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# StableCascadeUNet
|
||||
|
||||
A UNet model from the [Stable Cascade pipeline](../pipelines/stable_cascade.md).
|
||||
|
||||
## StableCascadeUNet
|
||||
|
||||
[[autodoc]] models.unets.unet_stable_cascade.StableCascadeUNet
|
||||
@@ -10,16 +10,18 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# limitations under the License.
|
||||
|
||||
## TODO: The paper is still being written.
|
||||
-->
|
||||
|
||||
# CogVideoX
|
||||
|
||||
[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://arxiv.org/abs/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang.
|
||||
[TODO]() from Tsinghua University & ZhipuAI.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We introduce CogVideoX, a large-scale diffusion transformer model designed for generating videos based on text prompts. To efficently model video data, we propose to levearge a 3D Variational Autoencoder (VAE) to compresses videos along both spatial and temporal dimensions. To improve the text-video alignment, we propose an expert transformer with the expert adaptive LayerNorm to facilitate the deep fusion between the two modalities. By employing a progressive training technique, CogVideoX is adept at producing coherent, long-duration videos characterized by significant motion. In addition, we develop an effectively text-video data processing pipeline that includes various data preprocessing strategies and a video captioning method. It significantly helps enhance the performance of CogVideoX, improving both generation quality and semantic alignment. Results show that CogVideoX demonstrates state-of-the-art performance across both multiple machine metrics and human evaluations. The model weight of CogVideoX-2B is publicly available at https://github.com/THUDM/CogVideo.*
|
||||
The paper is still being written.
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -27,9 +29,7 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
|
||||
|
||||
</Tip>
|
||||
|
||||
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
|
||||
|
||||
## Inference
|
||||
### Inference
|
||||
|
||||
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
|
||||
|
||||
@@ -37,46 +37,38 @@ First, load the pipeline:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
from diffusers import LattePipeline
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b").to("cuda")
|
||||
pipeline = LattePipeline.from_pretrained(
|
||||
"THUDM/CogVideoX-2b", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
Then change the memory layout of the pipelines `transformer` component to `torch.channels_last`:
|
||||
Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`:
|
||||
|
||||
```python
|
||||
pipe.transformer.to(memory_format=torch.channels_last)
|
||||
pipeline.transformer.to(memory_format=torch.channels_last)
|
||||
pipeline.vae.to(memory_format=torch.channels_last)
|
||||
```
|
||||
|
||||
Finally, compile the components and run inference:
|
||||
|
||||
```python
|
||||
pipe.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
|
||||
pipeline.transformer = torch.compile(pipeline.transformer)
|
||||
pipeline.vae.decode = torch.compile(pipeline.vae.decode)
|
||||
|
||||
# CogVideoX works well with long and well-described prompts
|
||||
# CogVideoX works very well with long and well-described prompts
|
||||
prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
|
||||
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
||||
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
||||
```
|
||||
|
||||
The [benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are:
|
||||
The [benchmark](TODO: link) results on an 80GB A100 machine are:
|
||||
|
||||
```
|
||||
Without torch.compile(): Average inference time: 96.89 seconds.
|
||||
With torch.compile(): Average inference time: 76.27 seconds.
|
||||
Without torch.compile(): Average inference time: TODO seconds.
|
||||
With torch.compile(): Average inference time: TODO seconds.
|
||||
```
|
||||
|
||||
### Memory optimization
|
||||
|
||||
CogVideoX requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
|
||||
|
||||
- `pipe.enable_model_cpu_offload()`:
|
||||
- Without enabling cpu offloading, memory usage is `33 GB`
|
||||
- With enabling cpu offloading, memory usage is `19 GB`
|
||||
- `pipe.vae.enable_tiling()`:
|
||||
- With enabling cpu offloading and tiling, memory usage is `11 GB`
|
||||
- `pipe.vae.enable_slicing()`
|
||||
|
||||
## CogVideoXPipeline
|
||||
|
||||
[[autodoc]] CogVideoXPipeline
|
||||
@@ -84,5 +76,4 @@ CogVideoX requires about 19 GB of GPU memory to decode 49 frames (6 seconds of v
|
||||
- __call__
|
||||
|
||||
## CogVideoXPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput
|
||||
[[autodoc]] pipelines.pipline_cogvideo.pipeline_output.CogVideoXPipelineOutput
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
<!--Copyright 2023 The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
@@ -22,16 +22,7 @@ The abstract from the paper is:
|
||||
|
||||
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
|
||||
|
||||
This controlnet code is mainly implemented by [The InstantX Team](https://huggingface.co/InstantX). The inpainting-related code was developed by [The Alimama Creative Team](https://huggingface.co/alimama-creative). You can find pre-trained checkpoints for SD3-ControlNet in the table below:
|
||||
|
||||
|
||||
| ControlNet type | Developer | Link |
|
||||
| -------- | ---------- | ---- |
|
||||
| Canny | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Canny) |
|
||||
| Pose | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Pose) |
|
||||
| Tile | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Tile) |
|
||||
| Inpainting | [The AlimamaCreative Team](https://huggingface.co/alimama-creative) | [link](https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting) |
|
||||
|
||||
This code is implemented by [The InstantX Team](https://huggingface.co/InstantX). You can find pre-trained checkpoints for SD3-ControlNet on [The InstantX Team](https://huggingface.co/InstantX) Hub profile.
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -44,10 +35,5 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusion3ControlNetInpaintingPipeline
|
||||
[[autodoc]] pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet_inpainting.StableDiffusion3ControlNetInpaintingPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusion3PipelineOutput
|
||||
[[autodoc]] pipelines.stable_diffusion_3.pipeline_output.StableDiffusion3PipelineOutput
|
||||
|
||||
@@ -37,7 +37,7 @@ Both checkpoints have slightly difference usage which we detail below.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
@@ -61,7 +61,7 @@ out.save("image.png")
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
@@ -77,89 +77,8 @@ out = pipe(
|
||||
out.save("image.png")
|
||||
```
|
||||
|
||||
## Running FP16 inference
|
||||
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
|
||||
|
||||
FP16 inference code:
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) # can replace schnell with dev
|
||||
# to run on low vram GPUs (i.e. between 4 and 32 GB VRAM)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
pipe.vae.enable_slicing()
|
||||
pipe.vae.enable_tiling()
|
||||
|
||||
pipe.to(torch.float16) # casting here instead of in the pipeline constructor because doing so in the constructor loads all models into CPU memory at once
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
out = pipe(
|
||||
prompt=prompt,
|
||||
guidance_scale=0.,
|
||||
height=768,
|
||||
width=1360,
|
||||
num_inference_steps=4,
|
||||
max_sequence_length=256,
|
||||
).images[0]
|
||||
out.save("image.png")
|
||||
```
|
||||
|
||||
## Single File Loading for the `FluxTransformer2DModel`
|
||||
|
||||
The `FluxTransformer2DModel` supports loading checkpoints in the original format shipped by Black Forest Labs. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
|
||||
|
||||
<Tip>
|
||||
`FP8` inference can be brittle depending on the GPU type, CUDA version, and `torch` version that you are using. It is recommended that you use the `optimum-quanto` library in order to run FP8 inference on your machine.
|
||||
</Tip>
|
||||
|
||||
The following example demonstrates how to run Flux with less than 16GB of VRAM.
|
||||
|
||||
First install `optimum-quanto`
|
||||
|
||||
```shell
|
||||
pip install optimum-quanto
|
||||
```
|
||||
|
||||
Then run the following example
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, FluxPipeline
|
||||
from transformers import T5EncoderModel, CLIPTextModel
|
||||
from optimum.quanto import freeze, qfloat8, quantize
|
||||
|
||||
bfl_repo = "black-forest-labs/FLUX.1-dev"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype)
|
||||
quantize(transformer, weights=qfloat8)
|
||||
freeze(transformer)
|
||||
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
|
||||
quantize(text_encoder_2, weights=qfloat8)
|
||||
freeze(text_encoder_2)
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
|
||||
pipe.transformer = transformer
|
||||
pipe.text_encoder_2 = text_encoder_2
|
||||
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(
|
||||
prompt,
|
||||
guidance_scale=3.5,
|
||||
output_type="pil",
|
||||
num_inference_steps=20,
|
||||
generator=torch.Generator("cpu").manual_seed(0)
|
||||
).images[0]
|
||||
|
||||
image.save("flux-fp8-dev.png")
|
||||
```
|
||||
|
||||
## FluxPipeline
|
||||
|
||||
[[autodoc]] FluxPipeline
|
||||
- all
|
||||
- __call__
|
||||
- __call__
|
||||
@@ -43,11 +43,6 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## KolorsPAGPipeline
|
||||
[[autodoc]] KolorsPAGPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionPAGPipeline
|
||||
[[autodoc]] StableDiffusionPAGPipeline
|
||||
- all
|
||||
@@ -79,12 +74,6 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
|
||||
- __call__
|
||||
|
||||
|
||||
## StableDiffusion3PAGPipeline
|
||||
[[autodoc]] StableDiffusion3PAGPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## PixArtSigmaPAGPipeline
|
||||
[[autodoc]] PixArtSigmaPAGPipeline
|
||||
- all
|
||||
|
||||
@@ -21,7 +21,7 @@ Stable Audio is trained on a corpus of around 48k audio recordings, where around
|
||||
The abstract of the paper is the following:
|
||||
*Open generative models are vitally important for the community, allowing for fine-tunes and serving as baselines when presenting new models. However, most current text-to-audio models are private and not accessible for artists and researchers to build upon. Here we describe the architecture and training process of a new open-weights text-to-audio model trained with Creative Commons data. Our evaluation shows that the model's performance is competitive with the state-of-the-art across various metrics. Notably, the reported FDopenl3 results (measuring the realism of the generations) showcase its potential for high-quality stereo sound synthesis at 44.1kHz.*
|
||||
|
||||
This pipeline was contributed by [Yoach Lacombe](https://huggingface.co/ylacombe). The original codebase can be found at [Stability-AI/stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools).
|
||||
This pipeline was contributed by [Yoach Lacombe](https://huggingface.co/ylacombe). The original codebase can be found at [Stability-AI/stable-audio-tool](https://github.com/Stability-AI/stable-audio-tool).
|
||||
|
||||
## Tips
|
||||
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Community Projects
|
||||
|
||||
Welcome to Community Projects. This space is dedicated to showcasing the incredible work and innovative applications created by our vibrant community using the `diffusers` library.
|
||||
|
||||
This section aims to:
|
||||
|
||||
- Highlight diverse and inspiring projects built with `diffusers`
|
||||
- Foster knowledge sharing within our community
|
||||
- Provide real-world examples of how `diffusers` can be leveraged
|
||||
|
||||
Happy exploring, and thank you for being part of the Diffusers community!
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th>Project Name</th>
|
||||
<th>Description</th>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/carson-katri/dream-textures"> dream-textures </a></td>
|
||||
<td>Stable Diffusion built-in to Blender</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/megvii-research/HiDiffusion"> HiDiffusion </a></td>
|
||||
<td>Increases the resolution and speed of your diffusion model by only adding a single line of code</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/lllyasviel/IC-Light"> IC-Light </a></td>
|
||||
<td>IC-Light is a project to manipulate the illumination of images</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/InstantID/InstantID"> InstantID </a></td>
|
||||
<td>InstantID : Zero-shot Identity-Preserving Generation in Seconds</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/Sanster/IOPaint"> IOPaint </a></td>
|
||||
<td>Image inpainting tool powered by SOTA AI Model. Remove any unwanted object, defect, people from your pictures or erase and replace(powered by stable diffusion) any thing on your pictures.</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/bmaltais/kohya_ss"> Kohya </a></td>
|
||||
<td>Gradio GUI for Kohya's Stable Diffusion trainers</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/magic-research/magic-animate"> MagicAnimate </a></td>
|
||||
<td>MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/levihsu/OOTDiffusion"> OOTDiffusion </a></td>
|
||||
<td>Outfitting Fusion based Latent Diffusion for Controllable Virtual Try-on</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/vladmandic/automatic"> SD.Next </a></td>
|
||||
<td>SD.Next: Advanced Implementation of Stable Diffusion and other Diffusion-based generative image models</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/ashawkey/stable-dreamfusion"> stable-dreamfusion </a></td>
|
||||
<td>Text-to-3D & Image-to-3D & Mesh Exportation with NeRF + Diffusion</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/HVision-NKU/StoryDiffusion"> StoryDiffusion </a></td>
|
||||
<td>StoryDiffusion can create a magic story by generating consistent images and videos.</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/cumulo-autumn/StreamDiffusion"> StreamDiffusion </a></td>
|
||||
<td>A Pipeline-Level Solution for Real-Time Interactive Generation</td>
|
||||
</tr>
|
||||
</table>
|
||||
@@ -125,5 +125,3 @@ image
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">distilled Stable Diffusion + Tiny AutoEncoder</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
More tiny autoencoder models for other Stable Diffusion models, like Stable Diffusion 3, are available from [madebyollin](https://huggingface.co/madebyollin).
|
||||
@@ -48,7 +48,7 @@ accelerate launch run_distributed.py --num_processes=2
|
||||
|
||||
<Tip>
|
||||
|
||||
Refer to this minimal example [script](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) for running inference across multiple GPUs. To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
|
||||
To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -108,4 +108,4 @@ torchrun run_distributed.py --nproc_per_node=2
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> You can use `device_map` within a [`DiffusionPipeline`] to distribute its model-level components on multiple devices. Refer to the [Device placement](../tutorials/inference_with_big_models#device-placement) guide to learn more.
|
||||
> You can use `device_map` within a [`DiffusionPipeline`] to distribute its model-level components on multiple devices. Refer to the [Device placement](../tutorials/inference_with_big_models#device-placement) guide to learn more.
|
||||
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
[InstructPix2Pix](https://hf.co/papers/2211.09800) is a Stable Diffusion model trained to edit images from human-provided instructions. For example, your prompt can be "turn the clouds rainy" and the model will edit the input image accordingly. This model is conditioned on the text prompt (or editing instruction) and the input image.
|
||||
|
||||
This guide will explore the [train_instruct_pix2pix.py](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py) training script to help you become familiar with it, and how you can adapt it for your own use case.
|
||||
This guide will explore the [train_instruct_pix2pix.py](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py) training script to help you become familiar with it, and how you can adapt it for your own use-case.
|
||||
|
||||
Before running the script, make sure you install the library from source:
|
||||
|
||||
@@ -117,7 +117,7 @@ optimizer = optimizer_cls(
|
||||
)
|
||||
```
|
||||
|
||||
Next, the edited images and edit instructions are [preprocessed](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L624) and [tokenized](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L610C24-L610C24). It is important the same image transformations are applied to the original and edited images.
|
||||
Next, the edited images and and edit instructions are [preprocessed](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L624) and [tokenized](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L610C24-L610C24). It is important the same image transformations are applied to the original and edited images.
|
||||
|
||||
```py
|
||||
def preprocess_train(examples):
|
||||
@@ -249,4 +249,4 @@ The SDXL training script is discussed in more detail in the [SDXL training](sdxl
|
||||
|
||||
Congratulations on training your own InstructPix2Pix model! 🥳 To learn more about the model, it may be helpful to:
|
||||
|
||||
- Read the [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd) blog post to learn more about some experiments we've done with InstructPix2Pix, dataset preparation, and results for different instructions.
|
||||
- Read the [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd) blog post to learn more about some experiments we've done with InstructPix2Pix, dataset preparation, and results for different instructions.
|
||||
@@ -34,7 +34,7 @@ pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")
|
||||
```
|
||||
|
||||
Next, load a [CiroN2022/toy-face](https://huggingface.co/CiroN2022/toy-face) adapter with the [`~diffusers.loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] method. With the 🤗 PEFT integration, you can assign a specific `adapter_name` to the checkpoint, which lets you easily switch between different LoRA checkpoints. Let's call this adapter `"toy"`.
|
||||
Next, load a [CiroN2022/toy-face](https://huggingface.co/CiroN2022/toy-face) adapter with the [`~diffusers.loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] method. With the 🤗 PEFT integration, you can assign a specific `adapter_name` to the checkpoint, which let's you easily switch between different LoRA checkpoints. Let's call this adapter `"toy"`.
|
||||
|
||||
```python
|
||||
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
|
||||
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Pipeline callbacks
|
||||
|
||||
The denoising loop of a pipeline can be modified with custom defined functions using the `callback_on_step_end` parameter. The callback function is executed at the end of each step, and modifies the pipeline attributes and variables for the next step. This is really useful for *dynamically* adjusting certain pipeline attributes or modifying tensor variables. This versatility allows for interesting use cases such as changing the prompt embeddings at each timestep, assigning different weights to the prompt embeddings, and editing the guidance scale. With callbacks, you can implement new features without modifying the underlying code!
|
||||
The denoising loop of a pipeline can be modified with custom defined functions using the `callback_on_step_end` parameter. The callback function is executed at the end of each step, and modifies the pipeline attributes and variables for the next step. This is really useful for *dynamically* adjusting certain pipeline attributes or modifying tensor variables. This versatility allows for interesting use-cases such as changing the prompt embeddings at each timestep, assigning different weights to the prompt embeddings, and editing the guidance scale. With callbacks, you can implement new features without modifying the underlying code!
|
||||
|
||||
> [!TIP]
|
||||
> 🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point!
|
||||
@@ -75,7 +75,7 @@ out.images[0].save("official_callback.png")
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">without SDXLCFGCutoffCallback</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/with_cfg_callback.png" alt="generated image of a sports car at the road with cfg callback" />
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/with_cfg_callback.png" alt="generated image of a a sports car at the road with cfg callback" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">with SDXLCFGCutoffCallback</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -289,9 +289,9 @@ scheduler = DPMSolverMultistepScheduler.from_pretrained(pipe_id, subfolder="sche
|
||||
3. Load an image processor:
|
||||
|
||||
```python
|
||||
from transformers import CLIPImageProcessor
|
||||
from transformers import CLIPFeatureExtractor
|
||||
|
||||
feature_extractor = CLIPImageProcessor.from_pretrained(pipe_id, subfolder="feature_extractor")
|
||||
feature_extractor = CLIPFeatureExtractor.from_pretrained(pipe_id, subfolder="feature_extractor")
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
@@ -212,14 +212,14 @@ TCD-LoRA is very versatile, and it can be combined with other adapter types like
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from transformers import DPTImageProcessor, DPTForDepthEstimation
|
||||
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
|
||||
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
from scheduling_tcd import TCDScheduler
|
||||
|
||||
device = "cuda"
|
||||
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
|
||||
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
|
||||
def get_depth_map(image):
|
||||
image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
|
||||
|
||||
@@ -14,9 +14,9 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
It can be fun and creative to use multiple [LoRAs]((https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora)) together to generate something entirely new and unique. This works by merging multiple LoRA weights together to produce images that are a blend of different styles. Diffusers provides a few methods to merge LoRAs depending on *how* you want to merge their weights, which can affect image quality.
|
||||
|
||||
This guide will show you how to merge LoRAs using the [`~loaders.PeftAdapterMixin.set_adapters`] and [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) methods. To improve inference speed and reduce memory-usage of merged LoRAs, you'll also see how to use the [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method to fuse the LoRA weights with the original weights of the underlying model.
|
||||
This guide will show you how to merge LoRAs using the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] and [`~peft.LoraModel.add_weighted_adapter`] methods. To improve inference speed and reduce memory-usage of merged LoRAs, you'll also see how to use the [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method to fuse the LoRA weights with the original weights of the underlying model.
|
||||
|
||||
For this guide, load a Stable Diffusion XL (SDXL) checkpoint and the [KappaNeuro/studio-ghibli-style](https://huggingface.co/KappaNeuro/studio-ghibli-style) and [Norod78/sdxl-chalkboarddrawing-lora](https://huggingface.co/Norod78/sdxl-chalkboarddrawing-lora) LoRAs with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method. You'll need to assign each LoRA an `adapter_name` to combine them later.
|
||||
For this guide, load a Stable Diffusion XL (SDXL) checkpoint and the [KappaNeuro/studio-ghibli-style]() and [Norod78/sdxl-chalkboarddrawing-lora]() LoRAs with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method. You'll need to assign each LoRA an `adapter_name` to combine them later.
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
@@ -29,7 +29,7 @@ pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_
|
||||
|
||||
## set_adapters
|
||||
|
||||
The [`~loaders.PeftAdapterMixin.set_adapters`] method merges LoRA adapters by concatenating their weighted matrices. Use the adapter name to specify which LoRAs to merge, and the `adapter_weights` parameter to control the scaling for each LoRA. For example, if `adapter_weights=[0.5, 0.5]`, then the merged LoRA output is an average of both LoRAs. Try adjusting the adapter weights to see how it affects the generated image!
|
||||
The [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] method merges LoRA adapters by concatenating their weighted matrices. Use the adapter name to specify which LoRAs to merge, and the `adapter_weights` parameter to control the scaling for each LoRA. For example, if `adapter_weights=[0.5, 0.5]`, then the merged LoRA output is an average of both LoRAs. Try adjusting the adapter weights to see how it affects the generated image!
|
||||
|
||||
```py
|
||||
pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
|
||||
@@ -47,19 +47,19 @@ image
|
||||
## add_weighted_adapter
|
||||
|
||||
> [!WARNING]
|
||||
> This is an experimental method that adds PEFTs [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method to Diffusers to enable more efficient merging methods. Check out this [issue](https://github.com/huggingface/diffusers/issues/6892) if you're interested in learning more about the motivation and design behind this integration.
|
||||
> This is an experimental method that adds PEFTs [`~peft.LoraModel.add_weighted_adapter`] method to Diffusers to enable more efficient merging methods. Check out this [issue](https://github.com/huggingface/diffusers/issues/6892) if you're interested in learning more about the motivation and design behind this integration.
|
||||
|
||||
The [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method provides access to more efficient merging method such as [TIES and DARE](https://huggingface.co/docs/peft/developer_guides/model_merging). To use these merging methods, make sure you have the latest stable version of Diffusers and PEFT installed.
|
||||
The [`~peft.LoraModel.add_weighted_adapter`] method provides access to more efficient merging method such as [TIES and DARE](https://huggingface.co/docs/peft/developer_guides/model_merging). To use these merging methods, make sure you have the latest stable version of Diffusers and PEFT installed.
|
||||
|
||||
```bash
|
||||
pip install -U diffusers peft
|
||||
```
|
||||
|
||||
There are three steps to merge LoRAs with the [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method:
|
||||
There are three steps to merge LoRAs with the [`~peft.LoraModel.add_weighted_adapter`] method:
|
||||
|
||||
1. Create a [PeftModel](https://huggingface.co/docs/peft/package_reference/peft_model#peft.PeftModel) from the underlying model and LoRA checkpoint.
|
||||
1. Create a [`~peft.PeftModel`] from the underlying model and LoRA checkpoint.
|
||||
2. Load a base UNet model and the LoRA adapters.
|
||||
3. Merge the adapters using the [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method and the merging method of your choice.
|
||||
3. Merge the adapters using the [`~peft.LoraModel.add_weighted_adapter`] method and the merging method of your choice.
|
||||
|
||||
Let's dive deeper into what these steps entail.
|
||||
|
||||
@@ -92,7 +92,7 @@ pipeline = DiffusionPipeline.from_pretrained(
|
||||
pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
|
||||
```
|
||||
|
||||
Now you'll create a [PeftModel](https://huggingface.co/docs/peft/package_reference/peft_model#peft.PeftModel) from the loaded LoRA checkpoint by combining the SDXL UNet and the LoRA UNet from the pipeline.
|
||||
Now you'll create a [`~peft.PeftModel`] from the loaded LoRA checkpoint by combining the SDXL UNet and the LoRA UNet from the pipeline.
|
||||
|
||||
```python
|
||||
from peft import get_peft_model, LoraConfig
|
||||
@@ -112,7 +112,7 @@ ikea_peft_model.load_state_dict(original_state_dict, strict=True)
|
||||
> [!TIP]
|
||||
> You can optionally push the ikea_peft_model to the Hub by calling `ikea_peft_model.push_to_hub("ikea_peft_model", token=TOKEN)`.
|
||||
|
||||
Repeat this process to create a [PeftModel](https://huggingface.co/docs/peft/package_reference/peft_model#peft.PeftModel) from the [lordjia/by-feng-zikai](https://huggingface.co/lordjia/by-feng-zikai) LoRA.
|
||||
Repeat this process to create a [`~peft.PeftModel`] from the [lordjia/by-feng-zikai](https://huggingface.co/lordjia/by-feng-zikai) LoRA.
|
||||
|
||||
```python
|
||||
pipeline.delete_adapters("ikea")
|
||||
@@ -148,7 +148,7 @@ model = PeftModel.from_pretrained(base_unet, "stevhliu/ikea_peft_model", use_saf
|
||||
model.load_adapter("stevhliu/feng_peft_model", use_safetensors=True, subfolder="feng", adapter_name="feng")
|
||||
```
|
||||
|
||||
3. Merge the adapters using the [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method and the merging method of your choice (learn more about other merging methods in this [blog post](https://huggingface.co/blog/peft_merging)). For this example, let's use the `"dare_linear"` method to merge the LoRAs.
|
||||
3. Merge the adapters using the [`~peft.LoraModel.add_weighted_adapter`] method and the merging method of your choice (learn more about other merging methods in this [blog post](https://huggingface.co/blog/peft_merging)). For this example, let's use the `"dare_linear"` method to merge the LoRAs.
|
||||
|
||||
> [!WARNING]
|
||||
> Keep in mind the LoRAs need to have the same rank to be merged!
|
||||
@@ -182,9 +182,9 @@ image
|
||||
|
||||
## fuse_lora
|
||||
|
||||
Both the [`~loaders.PeftAdapterMixin.set_adapters`] and [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) methods require loading the base model and the LoRA adapters separately which incurs some overhead. The [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method allows you to fuse the LoRA weights directly with the original weights of the underlying model. This way, you're only loading the model once which can increase inference and lower memory-usage.
|
||||
Both the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] and [`~peft.LoraModel.add_weighted_adapter`] methods require loading the base model and the LoRA adapters separately which incurs some overhead. The [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method allows you to fuse the LoRA weights directly with the original weights of the underlying model. This way, you're only loading the model once which can increase inference and lower memory-usage.
|
||||
|
||||
You can use PEFT to easily fuse/unfuse multiple adapters directly into the model weights (both UNet and text encoder) using the [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method, which can lead to a speed-up in inference and lower VRAM usage.
|
||||
You can use PEFT to easily fuse/unfuse multiple adapters directly into the model weights (both UNet and text encoder) using the [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method, which can lead to a speed-up in inference and lower VRAM usage.
|
||||
|
||||
For example, if you have a base model and adapters loaded and set as active with the following adapter weights:
|
||||
|
||||
@@ -199,7 +199,7 @@ pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_
|
||||
pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
|
||||
```
|
||||
|
||||
Fuse these LoRAs into the UNet with the [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method. The `lora_scale` parameter controls how much to scale the output by with the LoRA weights. It is important to make the `lora_scale` adjustments in the [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method because it won’t work if you try to pass `scale` to the `cross_attention_kwargs` in the pipeline.
|
||||
Fuse these LoRAs into the UNet with the [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method. The `lora_scale` parameter controls how much to scale the output by with the LoRA weights. It is important to make the `lora_scale` adjustments in the [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method because it won’t work if you try to pass `scale` to the `cross_attention_kwargs` in the pipeline.
|
||||
|
||||
```py
|
||||
pipeline.fuse_lora(adapter_names=["ikea", "feng"], lora_scale=1.0)
|
||||
@@ -226,7 +226,7 @@ image = pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai"
|
||||
image
|
||||
```
|
||||
|
||||
You can call [`~~loaders.lora_base.LoraBaseMixin.unfuse_lora`] to restore the original model's weights (for example, if you want to use a different `lora_scale` value). However, this only works if you've only fused one LoRA adapter to the original model. If you've fused multiple LoRAs, you'll need to reload the model.
|
||||
You can call [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] to restore the original model's weights (for example, if you want to use a different `lora_scale` value). However, this only works if you've only fused one LoRA adapter to the original model. If you've fused multiple LoRAs, you'll need to reload the model.
|
||||
|
||||
```py
|
||||
pipeline.unfuse_lora()
|
||||
|
||||
@@ -307,7 +307,7 @@ print(pipeline)
|
||||
|
||||
위의 코드 출력 결과를 확인해보면, `pipeline`은 [`StableDiffusionPipeline`]의 인스턴스이며, 다음과 같이 총 7개의 컴포넌트로 구성된다는 것을 알 수 있습니다.
|
||||
|
||||
- `"feature_extractor"`: [`~transformers.CLIPImageProcessor`]의 인스턴스
|
||||
- `"feature_extractor"`: [`~transformers.CLIPFeatureExtractor`]의 인스턴스
|
||||
- `"safety_checker"`: 유해한 컨텐츠를 스크리닝하기 위한 [컴포넌트](https://github.com/huggingface/diffusers/blob/e55687e1e15407f60f32242027b7bb8170e58266/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L32)
|
||||
- `"scheduler"`: [`PNDMScheduler`]의 인스턴스
|
||||
- `"text_encoder"`: [`~transformers.CLIPTextModel`]의 인스턴스
|
||||
|
||||
@@ -24,7 +24,7 @@ import PIL
|
||||
from PIL import Image
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
|
||||
def image_grid(imgs, rows, cols):
|
||||
|
||||
@@ -71,7 +71,6 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |
|
||||
| FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) |
|
||||
| AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
|
||||
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffsuion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
|
||||
@@ -1436,9 +1435,9 @@ import requests
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPModel
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel
|
||||
|
||||
feature_extractor = CLIPImageProcessor.from_pretrained(
|
||||
feature_extractor = CLIPFeatureExtractor.from_pretrained(
|
||||
"laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
||||
)
|
||||
clip_model = CLIPModel.from_pretrained(
|
||||
@@ -1647,6 +1646,7 @@ from diffusers import DiffusionPipeline
|
||||
scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1",
|
||||
subfolder="scheduler")
|
||||
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
|
||||
custom_pipeline="stable_diffusion_tensorrt_img2img",
|
||||
variant='fp16',
|
||||
@@ -1661,6 +1661,7 @@ pipe = pipe.to("cuda")
|
||||
url = "https://pajoca.com/wp-content/uploads/2022/09/tekito-yamakawa-1.png"
|
||||
response = requests.get(url)
|
||||
input_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
prompt = "photorealistic new zealand hills"
|
||||
image = pipe(prompt, image=input_image, strength=0.75,).images[0]
|
||||
image.save('tensorrt_img2img_new_zealand_hills.png')
|
||||
@@ -2121,7 +2122,7 @@ import torch
|
||||
import open_clip
|
||||
from open_clip import SimpleTokenizer
|
||||
from diffusers import DiffusionPipeline
|
||||
from transformers import CLIPImageProcessor, CLIPModel
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel
|
||||
|
||||
|
||||
def download_image(url):
|
||||
@@ -2129,7 +2130,7 @@ def download_image(url):
|
||||
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
# Loading additional models
|
||||
feature_extractor = CLIPImageProcessor.from_pretrained(
|
||||
feature_extractor = CLIPFeatureExtractor.from_pretrained(
|
||||
"laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
||||
)
|
||||
clip_model = CLIPModel.from_pretrained(
|
||||
@@ -4208,52 +4209,6 @@ print("Latency of AnimateDiffPipelineIpex--fp32", latency, "s for total", step,
|
||||
latency = elapsed_time(pipe4, num_inference_steps=step)
|
||||
print("Latency of AnimateDiffPipeline--fp32",latency, "s for total", step, "steps")
|
||||
```
|
||||
### HunyuanDiT with Differential Diffusion
|
||||
|
||||
#### Usage
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import load_image
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
from pipeline_hunyuandit_differential_img2img import (
|
||||
HunyuanDiTDifferentialImg2ImgPipeline,
|
||||
)
|
||||
|
||||
|
||||
pipe = HunyuanDiTDifferentialImg2ImgPipeline.from_pretrained(
|
||||
"Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
|
||||
source_image = load_image(
|
||||
"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png"
|
||||
)
|
||||
map = load_image(
|
||||
"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask_2.png"
|
||||
)
|
||||
prompt = "a green pear"
|
||||
negative_prompt = "blurry"
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
image=source_image,
|
||||
num_inference_steps=28,
|
||||
guidance_scale=4.5,
|
||||
strength=1.0,
|
||||
map=map,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
|  |  |  |
|
||||
| ------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- |
|
||||
| Gradient | Input | Output |
|
||||
|
||||
A colab notebook demonstrating all results can be found [here](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing). Depth Maps have also been added in the same colab.
|
||||
|
||||
# Perturbed-Attention Guidance
|
||||
|
||||
@@ -4330,4 +4285,4 @@ grid_image.save(grid_dir + "sample.png")
|
||||
|
||||
`pag_scale` : guidance scale of PAG (ex: 5.0)
|
||||
|
||||
`pag_applied_layers_index` : index of the layer to apply perturbation (ex: ['m0'])
|
||||
`pag_applied_layers_index` : index of the layer to apply perturbation (ex: ['m0'])
|
||||
@@ -7,7 +7,7 @@ import PIL.Image
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -86,7 +86,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline, StableDiffusionMi
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
coca_model=None,
|
||||
coca_tokenizer=None,
|
||||
coca_transform=None,
|
||||
|
||||
@@ -7,7 +7,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -32,9 +32,9 @@ EXAMPLE_DOC_STRING = """
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPModel
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel
|
||||
|
||||
feature_extractor = CLIPImageProcessor.from_pretrained(
|
||||
feature_extractor = CLIPFeatureExtractor.from_pretrained(
|
||||
"laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
||||
)
|
||||
clip_model = CLIPModel.from_pretrained(
|
||||
@@ -139,7 +139,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
|
||||
@@ -9,7 +9,7 @@ import torch
|
||||
from numpy import exp, pi, sqrt
|
||||
from torchvision.transforms.functional import resize
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
@@ -275,7 +275,7 @@ class StableDiffusionCanvasPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
|
||||
@@ -15,7 +15,7 @@ from diffusers.utils import logging
|
||||
|
||||
try:
|
||||
from ligo.segments import segment
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
except ImportError:
|
||||
raise ImportError("Please install transformers and ligo-segments to use the mixture pipeline")
|
||||
|
||||
@@ -144,7 +144,7 @@ class StableDiffusionTilingPipeline(DiffusionPipeline, StableDiffusionExtrasMixi
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -189,7 +189,7 @@ class StableDiffusionXLControlNetAdapterPipeline(
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -332,7 +332,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
|
||||
Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config
|
||||
|
||||
@@ -9,7 +9,7 @@ import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from packaging import version
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
|
||||
|
||||
# from ...configuration_utils import FrozenDict
|
||||
# from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
@@ -87,7 +87,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
cc_projection ([`CCProjection`]):
|
||||
Projection layer to project the concated CLIP features and pose embeddings to the original CLIP feature size.
|
||||
@@ -102,7 +102,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
cc_projection: CCProjection,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torchvision.transforms.functional as FF
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
@@ -69,7 +69,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
import intel_extension_for_pytorch as ipex
|
||||
import torch
|
||||
from packaging import version
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
@@ -86,7 +86,7 @@ class StableDiffusionIPEXPipeline(
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -100,7 +100,7 @@ class StableDiffusionIPEXPipeline(
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -42,7 +42,7 @@ from polygraphy.backend.trt import (
|
||||
network_from_onnx_path,
|
||||
save_engine,
|
||||
)
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.configuration_utils import FrozenDict, deprecate
|
||||
@@ -679,7 +679,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -693,7 +693,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DDIMScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
stages=["clip", "unet", "vae", "vae_encoder"],
|
||||
|
||||
@@ -42,7 +42,7 @@ from polygraphy.backend.trt import (
|
||||
network_from_onnx_path,
|
||||
save_engine,
|
||||
)
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.configuration_utils import FrozenDict, deprecate
|
||||
@@ -683,7 +683,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -697,7 +697,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DDIMScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
stages=["clip", "unet", "vae", "vae_encoder"],
|
||||
|
||||
@@ -42,7 +42,7 @@ from polygraphy.backend.trt import (
|
||||
network_from_onnx_path,
|
||||
save_engine,
|
||||
)
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.configuration_utils import FrozenDict, deprecate
|
||||
@@ -595,7 +595,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -609,7 +609,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DDIMScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
stages=["clip", "unet", "vae"],
|
||||
|
||||
@@ -1,195 +0,0 @@
|
||||
# DreamBooth training example for FLUX.1 [dev]
|
||||
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
|
||||
|
||||
The `train_dreambooth_flux.py` script shows how to implement the training procedure and adapt it for [FLUX.1 [dev]](https://blackforestlabs.ai/announcing-black-forest-labs/). We also provide a LoRA implementation in the `train_dreambooth_lora_flux.py` script.
|
||||
> [!NOTE]
|
||||
> **Memory consumption**
|
||||
>
|
||||
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
|
||||
> a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training.
|
||||
> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)
|
||||
|
||||
|
||||
> [!NOTE]
|
||||
> **Gated model**
|
||||
>
|
||||
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/dreambooth` folder and run
|
||||
```bash
|
||||
pip install -r requirements_flux.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
|
||||
### Dog toy example
|
||||
|
||||
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
|
||||
|
||||
Let's first download it locally:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./dog"
|
||||
snapshot_download(
|
||||
"diffusers/dog-example",
|
||||
local_dir=local_dir, repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
|
||||
|
||||
Now, we can launch training using:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="trained-flux"
|
||||
|
||||
accelerate launch train_dreambooth_flux.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mixed_precision="bf16" \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--learning_rate=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=500 \
|
||||
--validation_prompt="A photo of sks dog in a bucket" \
|
||||
--validation_epochs=25 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
> [!NOTE]
|
||||
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
|
||||
|
||||
> [!TIP]
|
||||
> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
|
||||
|
||||
## LoRA + DreamBooth
|
||||
|
||||
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
|
||||
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
To perform DreamBooth with LoRA, run:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="trained-flux-lora"
|
||||
|
||||
accelerate launch train_dreambooth_lora_flux.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mixed_precision="bf16" \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--learning_rate=1e-5 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=500 \
|
||||
--validation_prompt="A photo of sks dog in a bucket" \
|
||||
--validation_epochs=25 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
### Text Encoder Training
|
||||
|
||||
Alongside the transformer, fine-tuning of the CLIP text encoder is also supported.
|
||||
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
|
||||
|
||||
> [!NOTE]
|
||||
> FLUX.1 has 2 text encoders (CLIP L/14 and T5-v1.1-XXL).
|
||||
By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is performed.
|
||||
> At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled.
|
||||
|
||||
To perform DreamBooth LoRA with text-encoder training, run:
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
|
||||
export OUTPUT_DIR="trained-flux-dev-dreambooth-lora"
|
||||
|
||||
accelerate launch train_dreambooth_lora_flux.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mixed_precision="bf16" \
|
||||
--train_text_encoder\
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--learning_rate=1e-5 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=500 \
|
||||
--validation_prompt="A photo of sks dog in a bucket" \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
## Other notes
|
||||
Thanks to `bghira` for their help with reviewing & insight sharing ♥️
|
||||
@@ -1,8 +0,0 @@
|
||||
accelerate>=0.31.0
|
||||
torchvision
|
||||
transformers>=4.41.2
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft>=0.11.1
|
||||
sentencepiece
|
||||
@@ -1,203 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from diffusers import DiffusionPipeline, FluxTransformer2DModel
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothFlux(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
instance_prompt = "photo"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
|
||||
script_path = "examples/dreambooth/train_dreambooth_flux.py"
|
||||
|
||||
def test_dreambooth(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
|
||||
|
||||
def test_dreambooth_checkpointing(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 4, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4
|
||||
|
||||
initial_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 4
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--seed=0
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + initial_run_args)
|
||||
|
||||
# check can run the original fully trained output pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir)
|
||||
pipe(self.instance_prompt, num_inference_steps=1)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
|
||||
|
||||
# check can run an intermediate checkpoint
|
||||
transformer = FluxTransformer2DModel.from_pretrained(tmpdir, subfolder="checkpoint-2/transformer")
|
||||
pipe = DiffusionPipeline.from_pretrained(self.pretrained_model_name_or_path, transformer=transformer)
|
||||
pipe(self.instance_prompt, num_inference_steps=1)
|
||||
|
||||
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
|
||||
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
|
||||
|
||||
# Run training script for 7 total steps resuming from checkpoint 4
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 6
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--seed=0
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
# check can run new fully trained pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir)
|
||||
pipe(self.instance_prompt, num_inference_steps=1)
|
||||
|
||||
# check old checkpoints do not exist
|
||||
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
|
||||
|
||||
# check new checkpoints exist
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
|
||||
|
||||
def test_dreambooth_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
@@ -1,165 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
instance_prompt = "photo"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
|
||||
script_path = "examples/dreambooth/train_dreambooth_lora_flux.py"
|
||||
|
||||
def test_dreambooth_lora_flux(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_text_encoder_flux(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--train_text_encoder
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
starts_with_expected_prefix = all(
|
||||
(key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
|
||||
)
|
||||
self.assertTrue(starts_with_expected_prefix)
|
||||
|
||||
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1271,7 +1271,7 @@ def main(args):
|
||||
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -43,7 +43,7 @@ from PIL import Image
|
||||
from torch.utils.data import default_collate
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, DPTForDepthEstimation, DPTImageProcessor, PretrainedConfig
|
||||
from transformers import AutoTokenizer, DPTFeatureExtractor, DPTForDepthEstimation, PretrainedConfig
|
||||
from webdataset.tariterators import (
|
||||
base_plus_ext,
|
||||
tar_file_expander,
|
||||
@@ -205,7 +205,7 @@ class Text2ImageDataset:
|
||||
pin_memory: bool = False,
|
||||
persistent_workers: bool = False,
|
||||
control_type: str = "canny",
|
||||
feature_extractor: Optional[DPTImageProcessor] = None,
|
||||
feature_extractor: Optional[DPTFeatureExtractor] = None,
|
||||
):
|
||||
if not isinstance(train_shards_path_or_url, str):
|
||||
train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
|
||||
@@ -1011,7 +1011,7 @@ def main(args):
|
||||
controlnet = pre_controlnet
|
||||
|
||||
if args.control_type == "depth":
|
||||
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
depth_model.requires_grad_(False)
|
||||
else:
|
||||
|
||||
@@ -45,7 +45,7 @@
|
||||
" UniPCMultistepScheduler,\n",
|
||||
" EulerDiscreteScheduler,\n",
|
||||
")\n",
|
||||
"from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n",
|
||||
"from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n",
|
||||
"# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
|
||||
"\n",
|
||||
"pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n",
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Callable, List, Optional, Union
|
||||
import torch
|
||||
from PIL import Image
|
||||
from retriever import Retriever, normalize_images, preprocess_images
|
||||
from transformers import CLIPImageProcessor, CLIPModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -47,7 +47,7 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -65,7 +65,7 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
],
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
retriever: Optional[Retriever] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
import torch
|
||||
from datasets import Dataset, load_dataset
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPModel, PretrainedConfig
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, PretrainedConfig
|
||||
|
||||
from diffusers import logging
|
||||
|
||||
@@ -20,7 +20,7 @@ def normalize_images(images: List[Image.Image]):
|
||||
return images
|
||||
|
||||
|
||||
def preprocess_images(images: List[np.array], feature_extractor: CLIPImageProcessor) -> torch.Tensor:
|
||||
def preprocess_images(images: List[np.array], feature_extractor: CLIPFeatureExtractor) -> torch.Tensor:
|
||||
"""
|
||||
Preprocesses a list of images into a batch of tensors.
|
||||
|
||||
@@ -95,12 +95,14 @@ class Index:
|
||||
def build_index(
|
||||
self,
|
||||
model=None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
feature_extractor: CLIPFeatureExtractor = None,
|
||||
torch_dtype=torch.float32,
|
||||
):
|
||||
if not self.index_initialized:
|
||||
model = model or CLIPModel.from_pretrained(self.config.clip_name_or_path).to(dtype=torch_dtype)
|
||||
feature_extractor = feature_extractor or CLIPImageProcessor.from_pretrained(self.config.clip_name_or_path)
|
||||
feature_extractor = feature_extractor or CLIPFeatureExtractor.from_pretrained(
|
||||
self.config.clip_name_or_path
|
||||
)
|
||||
self.dataset = get_dataset_with_emb_from_clip_model(
|
||||
self.dataset,
|
||||
model,
|
||||
@@ -134,7 +136,7 @@ class Retriever:
|
||||
index: Index = None,
|
||||
dataset: Dataset = None,
|
||||
model=None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
feature_extractor: CLIPFeatureExtractor = None,
|
||||
):
|
||||
self.config = config
|
||||
self.index = index or self._build_index(config, dataset, model=model, feature_extractor=feature_extractor)
|
||||
@@ -146,7 +148,7 @@ class Retriever:
|
||||
index: Index = None,
|
||||
dataset: Dataset = None,
|
||||
model=None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
feature_extractor: CLIPFeatureExtractor = None,
|
||||
**kwargs,
|
||||
):
|
||||
config = kwargs.pop("config", None) or IndexConfig.from_pretrained(retriever_name_or_path, **kwargs)
|
||||
@@ -154,7 +156,7 @@ class Retriever:
|
||||
|
||||
@staticmethod
|
||||
def _build_index(
|
||||
config: IndexConfig, dataset: Dataset = None, model=None, feature_extractor: CLIPImageProcessor = None
|
||||
config: IndexConfig, dataset: Dataset = None, model=None, feature_extractor: CLIPFeatureExtractor = None
|
||||
):
|
||||
dataset = dataset or load_dataset(config.dataset_name)
|
||||
dataset = dataset[config.dataset_set]
|
||||
|
||||
@@ -18,7 +18,7 @@ cc.initialize_cache("/tmp/sdxl_cache")
|
||||
NUM_DEVICES = jax.device_count()
|
||||
|
||||
# 1. Let's start by downloading the model and loading it into our pipeline class
|
||||
# Adhering to JAX's functional approach, the model's parameters are returned separately and
|
||||
# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and
|
||||
# will have to be passed to the pipeline during inference
|
||||
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
|
||||
@@ -69,7 +69,7 @@ def replicate_all(prompt_ids, neg_prompt_ids, seed):
|
||||
# to the function and tell JAX which are static arguments, that is, arguments that
|
||||
# are known at compile time and won't change. In our case, it is num_inference_steps,
|
||||
# height, width and return_latents.
|
||||
# Once the function is compiled, these parameters are omitted from future calls and
|
||||
# Once the function is compiled, these parameters are ommited from future calls and
|
||||
# cannot be changed without modifying the code and recompiling.
|
||||
def aot_compile(
|
||||
prompt=default_prompt,
|
||||
|
||||
@@ -826,22 +826,17 @@ def main():
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
@@ -871,14 +866,8 @@ def main():
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
|
||||
@@ -478,7 +478,7 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--debug_loss",
|
||||
action="store_true",
|
||||
help="debug loss for each image, if filenames are available in the dataset",
|
||||
help="debug loss for each image, if filenames are awailable in the dataset",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
|
||||
@@ -23,25 +23,4 @@ accelerate launch textual_inversion_sdxl.py \
|
||||
--output_dir="./textual_inversion_cat_sdxl"
|
||||
```
|
||||
|
||||
Training of both text encoders is supported.
|
||||
|
||||
### Inference Example
|
||||
|
||||
Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionXLPipeline`.
|
||||
Make sure to include the `placeholder_token` in your prompt.
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
model_id = "./textual_inversion_cat_sdxl"
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(model_id,torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
prompt = "A <cat-toy> backpack"
|
||||
|
||||
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
|
||||
image.save("cat-backpack.png")
|
||||
|
||||
image = pipe(prompt="", prompt_2=prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
|
||||
image.save("cat-backpack-prompt_2.png")
|
||||
```
|
||||
For now, only training of the first text encoder is supported.
|
||||
@@ -135,7 +135,7 @@ def log_validation(
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder_1),
|
||||
text_encoder_2=accelerator.unwrap_model(text_encoder_2),
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer=tokenizer_1,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
@@ -678,54 +678,36 @@ def main():
|
||||
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
)
|
||||
num_added_tokens = tokenizer_2.add_tokens(placeholder_tokens)
|
||||
if num_added_tokens != args.num_vectors:
|
||||
raise ValueError(
|
||||
f"The 2nd tokenizer already contains the token {args.placeholder_token}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
# Convert the initializer_token, placeholder_token to ids
|
||||
token_ids = tokenizer_1.encode(args.initializer_token, add_special_tokens=False)
|
||||
token_ids_2 = tokenizer_2.encode(args.initializer_token, add_special_tokens=False)
|
||||
|
||||
# Check if initializer_token is a single token or a sequence of tokens
|
||||
if len(token_ids) > 1 or len(token_ids_2) > 1:
|
||||
if len(token_ids) > 1:
|
||||
raise ValueError("The initializer token must be a single token.")
|
||||
|
||||
initializer_token_id = token_ids[0]
|
||||
placeholder_token_ids = tokenizer_1.convert_tokens_to_ids(placeholder_tokens)
|
||||
initializer_token_id_2 = token_ids_2[0]
|
||||
placeholder_token_ids_2 = tokenizer_2.convert_tokens_to_ids(placeholder_tokens)
|
||||
|
||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||
text_encoder_1.resize_token_embeddings(len(tokenizer_1))
|
||||
text_encoder_2.resize_token_embeddings(len(tokenizer_2))
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder_1.get_input_embeddings().weight.data
|
||||
token_embeds_2 = text_encoder_2.get_input_embeddings().weight.data
|
||||
with torch.no_grad():
|
||||
for token_id in placeholder_token_ids:
|
||||
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
|
||||
for token_id in placeholder_token_ids_2:
|
||||
token_embeds_2[token_id] = token_embeds_2[initializer_token_id_2].clone()
|
||||
|
||||
# Freeze vae and unet
|
||||
vae.requires_grad_(False)
|
||||
unet.requires_grad_(False)
|
||||
|
||||
text_encoder_2.requires_grad_(False)
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
text_encoder_1.text_model.encoder.requires_grad_(False)
|
||||
text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
|
||||
text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
|
||||
text_encoder_2.text_model.encoder.requires_grad_(False)
|
||||
text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
|
||||
text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
text_encoder_1.gradient_checkpointing_enable()
|
||||
text_encoder_2.gradient_checkpointing_enable()
|
||||
|
||||
if args.enable_xformers_memory_efficient_attention:
|
||||
if is_xformers_available():
|
||||
@@ -764,11 +746,7 @@ def main():
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
optimizer = optimizer_class(
|
||||
# only optimize the embeddings
|
||||
[
|
||||
text_encoder_1.text_model.embeddings.token_embedding.weight,
|
||||
text_encoder_2.text_model.embeddings.token_embedding.weight,
|
||||
],
|
||||
text_encoder_1.get_input_embeddings().parameters(), # only optimize the embeddings
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
@@ -808,10 +786,9 @@ def main():
|
||||
)
|
||||
|
||||
text_encoder_1.train()
|
||||
text_encoder_2.train()
|
||||
# Prepare everything with our `accelerator`.
|
||||
text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler
|
||||
text_encoder_1, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder_1, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
|
||||
@@ -889,13 +866,11 @@ def main():
|
||||
|
||||
# keep original embeddings as reference
|
||||
orig_embeds_params = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone()
|
||||
orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone()
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
text_encoder_1.train()
|
||||
text_encoder_2.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate([text_encoder_1, text_encoder_2]):
|
||||
with accelerator.accumulate(text_encoder_1):
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
|
||||
latents = latents * vae.config.scaling_factor
|
||||
@@ -917,7 +892,9 @@ def main():
|
||||
.hidden_states[-2]
|
||||
.to(dtype=weight_dtype)
|
||||
)
|
||||
encoder_output_2 = text_encoder_2(batch["input_ids_2"], output_hidden_states=True)
|
||||
encoder_output_2 = text_encoder_2(
|
||||
batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True
|
||||
)
|
||||
encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype)
|
||||
original_size = [
|
||||
(batch["original_size"][0][i].item(), batch["original_size"][1][i].item())
|
||||
@@ -961,16 +938,11 @@ def main():
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
index_no_updates = torch.ones((len(tokenizer_1),), dtype=torch.bool)
|
||||
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
|
||||
index_no_updates_2 = torch.ones((len(tokenizer_2),), dtype=torch.bool)
|
||||
index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
|
||||
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
|
||||
index_no_updates_2
|
||||
] = orig_embeds_params_2[index_no_updates_2]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
@@ -988,16 +960,6 @@ def main():
|
||||
save_path,
|
||||
safe_serialization=True,
|
||||
)
|
||||
weight_name = f"learned_embeds_2-steps-{global_step}.safetensors"
|
||||
save_path = os.path.join(args.output_dir, weight_name)
|
||||
save_progress(
|
||||
text_encoder_2,
|
||||
placeholder_token_ids_2,
|
||||
accelerator,
|
||||
args,
|
||||
save_path,
|
||||
safe_serialization=True,
|
||||
)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
@@ -1072,7 +1034,7 @@ def main():
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder_1),
|
||||
text_encoder_2=accelerator.unwrap_model(text_encoder_2),
|
||||
text_encoder_2=text_encoder_2,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer_1,
|
||||
@@ -1090,16 +1052,6 @@ def main():
|
||||
save_path,
|
||||
safe_serialization=True,
|
||||
)
|
||||
weight_name = "learned_embeds_2.safetensors"
|
||||
save_path = os.path.join(args.output_dir, weight_name)
|
||||
save_progress(
|
||||
text_encoder_2,
|
||||
placeholder_token_ids_2,
|
||||
accelerator,
|
||||
args,
|
||||
save_path,
|
||||
safe_serialization=True,
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
save_model_card(
|
||||
|
||||
@@ -12,7 +12,6 @@ from .utils import (
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_scipy_available,
|
||||
is_sentencepiece_available,
|
||||
is_torch_available,
|
||||
is_torchsde_available,
|
||||
is_transformers_available,
|
||||
@@ -88,7 +87,6 @@ else:
|
||||
"ControlNetModel",
|
||||
"ControlNetXSAdapter",
|
||||
"DiTTransformer2DModel",
|
||||
"FluxControlNetModel",
|
||||
"FluxTransformer2DModel",
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
@@ -252,10 +250,11 @@ else:
|
||||
"AuraFlowPipeline",
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
"BlipDiffusionPipeline",
|
||||
"ChatGLMModel",
|
||||
"ChatGLMTokenizer",
|
||||
"CLIPImageProjection",
|
||||
"CogVideoXPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
"FluxControlNetPipeline",
|
||||
"FluxPipeline",
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
@@ -287,6 +286,8 @@ else:
|
||||
"KandinskyV22Pipeline",
|
||||
"KandinskyV22PriorEmb2EmbPipeline",
|
||||
"KandinskyV22PriorPipeline",
|
||||
"KolorsImg2ImgPipeline",
|
||||
"KolorsPipeline",
|
||||
"LatentConsistencyModelImg2ImgPipeline",
|
||||
"LatentConsistencyModelPipeline",
|
||||
"LattePipeline",
|
||||
@@ -310,11 +311,9 @@ else:
|
||||
"StableCascadeCombinedPipeline",
|
||||
"StableCascadeDecoderPipeline",
|
||||
"StableCascadePriorPipeline",
|
||||
"StableDiffusion3ControlNetInpaintingPipeline",
|
||||
"StableDiffusion3ControlNetPipeline",
|
||||
"StableDiffusion3Img2ImgPipeline",
|
||||
"StableDiffusion3InpaintPipeline",
|
||||
"StableDiffusion3PAGPipeline",
|
||||
"StableDiffusion3Pipeline",
|
||||
"StableDiffusionAdapterPipeline",
|
||||
"StableDiffusionAttendAndExcitePipeline",
|
||||
@@ -392,19 +391,6 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"])
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403
|
||||
|
||||
_import_structure["utils.dummy_torch_and_transformers_and_sentencepiece_objects"] = [
|
||||
name for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
else:
|
||||
_import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"])
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -552,7 +538,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ControlNetModel,
|
||||
ControlNetXSAdapter,
|
||||
DiTTransformer2DModel,
|
||||
FluxControlNetModel,
|
||||
FluxTransformer2DModel,
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
@@ -694,10 +679,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
AudioLDMPipeline,
|
||||
AuraFlowPipeline,
|
||||
ChatGLMModel,
|
||||
ChatGLMTokenizer,
|
||||
CLIPImageProjection,
|
||||
CogVideoXPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
FluxControlNetPipeline,
|
||||
FluxPipeline,
|
||||
HunyuanDiTControlNetPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
@@ -729,6 +715,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
KandinskyV22Pipeline,
|
||||
KandinskyV22PriorEmb2EmbPipeline,
|
||||
KandinskyV22PriorPipeline,
|
||||
KolorsImg2ImgPipeline,
|
||||
KolorsPipeline,
|
||||
LatentConsistencyModelImg2ImgPipeline,
|
||||
LatentConsistencyModelPipeline,
|
||||
LattePipeline,
|
||||
@@ -755,7 +743,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusion3ControlNetPipeline,
|
||||
StableDiffusion3Img2ImgPipeline,
|
||||
StableDiffusion3InpaintPipeline,
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusion3Pipeline,
|
||||
StableDiffusionAdapterPipeline,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
@@ -827,13 +814,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -24,7 +24,6 @@ from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
convert_animatediff_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
@@ -75,13 +74,6 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"MotionAdapter": {
|
||||
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
||||
},
|
||||
"SparseControlNetModel": {
|
||||
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
||||
},
|
||||
"FluxTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -74,12 +74,9 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
|
||||
"stable_cascade_stage_c": "clip_txt_mapper.weight",
|
||||
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
|
||||
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
|
||||
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe",
|
||||
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
|
||||
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
|
||||
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
|
||||
"animatediff_rgb": "controlnet_cond_embedding.weight",
|
||||
"flux": "double_blocks.0.img_attn.norm.key_norm.scale",
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -113,10 +110,6 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
|
||||
"animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
|
||||
"animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
|
||||
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
|
||||
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
|
||||
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
|
||||
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -498,13 +491,7 @@ def infer_diffusers_model_type(checkpoint):
|
||||
model_type = "sd3"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
|
||||
if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
|
||||
model_type = "animatediff_scribble"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint:
|
||||
model_type = "animatediff_rgb"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
|
||||
if CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
|
||||
model_type = "animatediff_v2"
|
||||
|
||||
elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
|
||||
@@ -516,11 +503,6 @@ def infer_diffusers_model_type(checkpoint):
|
||||
else:
|
||||
model_type = "animatediff_v3"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["flux"] in checkpoint:
|
||||
if "guidance_in.in_layer.bias" in checkpoint:
|
||||
model_type = "flux-dev"
|
||||
else:
|
||||
model_type = "flux-schnell"
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -1877,195 +1859,3 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
] = v
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
|
||||
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
|
||||
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
|
||||
mlp_ratio = 4.0
|
||||
inner_dim = 3072
|
||||
|
||||
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
|
||||
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
## time_text_embed.timestep_embedder <- time_in
|
||||
converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
|
||||
"time_in.in_layer.weight"
|
||||
)
|
||||
converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
|
||||
converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
|
||||
"time_in.out_layer.weight"
|
||||
)
|
||||
converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
|
||||
|
||||
## time_text_embed.text_embedder <- vector_in
|
||||
converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
|
||||
converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
|
||||
converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
|
||||
"vector_in.out_layer.weight"
|
||||
)
|
||||
converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
|
||||
|
||||
# guidance
|
||||
has_guidance = any("guidance" in k for k in checkpoint)
|
||||
if has_guidance:
|
||||
converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
|
||||
"guidance_in.in_layer.weight"
|
||||
)
|
||||
converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
|
||||
"guidance_in.in_layer.bias"
|
||||
)
|
||||
converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
|
||||
"guidance_in.out_layer.weight"
|
||||
)
|
||||
converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
|
||||
"guidance_in.out_layer.bias"
|
||||
)
|
||||
|
||||
# context_embedder
|
||||
converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
|
||||
converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
|
||||
|
||||
# x_embedder
|
||||
converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
|
||||
converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
|
||||
|
||||
# double transformer blocks
|
||||
for i in range(num_layers):
|
||||
block_prefix = f"transformer_blocks.{i}."
|
||||
# norms.
|
||||
## norm1
|
||||
converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_mod.lin.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_mod.lin.bias"
|
||||
)
|
||||
## norm1_context
|
||||
converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mod.lin.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mod.lin.bias"
|
||||
)
|
||||
# Q, K, V
|
||||
sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
|
||||
context_q, context_k, context_v = torch.chunk(
|
||||
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
|
||||
)
|
||||
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
|
||||
checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
|
||||
)
|
||||
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
|
||||
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
|
||||
# qk_norm
|
||||
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_attn.norm.query_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_attn.norm.key_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
|
||||
)
|
||||
# ff img_mlp
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_mlp.0.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mlp.0.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mlp.0.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mlp.2.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mlp.2.bias"
|
||||
)
|
||||
# output projections.
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_attn.proj.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_attn.proj.bias"
|
||||
)
|
||||
|
||||
# single transfomer blocks
|
||||
for i in range(num_single_layers):
|
||||
block_prefix = f"single_transformer_blocks.{i}."
|
||||
# norm.linear <- single_blocks.0.modulation.lin
|
||||
converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
|
||||
f"single_blocks.{i}.modulation.lin.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
|
||||
f"single_blocks.{i}.modulation.lin.bias"
|
||||
)
|
||||
# Q, K, V, mlp
|
||||
mlp_hidden_dim = int(inner_dim * mlp_ratio)
|
||||
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
|
||||
q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
|
||||
q_bias, k_bias, v_bias, mlp_bias = torch.split(
|
||||
checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
|
||||
converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
|
||||
converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
|
||||
# qk norm
|
||||
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
|
||||
f"single_blocks.{i}.norm.query_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
|
||||
f"single_blocks.{i}.norm.key_norm.scale"
|
||||
)
|
||||
# output projections.
|
||||
converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
|
||||
converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
|
||||
|
||||
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
|
||||
checkpoint.pop("final_layer.adaLN_modulation.1.weight")
|
||||
)
|
||||
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
|
||||
checkpoint.pop("final_layer.adaLN_modulation.1.bias")
|
||||
)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -35,7 +35,6 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
||||
_import_structure["controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnet_flux"] = ["FluxControlNetModel"]
|
||||
_import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
|
||||
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
|
||||
_import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
|
||||
@@ -88,7 +87,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VQModel,
|
||||
)
|
||||
from .controlnet import ControlNetModel
|
||||
from .controlnet_flux import FluxControlNetModel
|
||||
from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
|
||||
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
|
||||
from .controlnet_sparsectrl import SparseControlNetModel
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -272,17 +272,6 @@ class BasicTransformerBlock(nn.Module):
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.dropout = dropout
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.double_self_attention = double_self_attention
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.positional_embeddings = positional_embeddings
|
||||
self.num_positional_embeddings = num_positional_embeddings
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
# We keep these boolean flags for backward-compatibility.
|
||||
@@ -387,7 +376,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
"layer_norm",
|
||||
)
|
||||
|
||||
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
|
||||
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm", "ada_norm_continuous"]:
|
||||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
elif norm_type == "layer_norm_i2vgen":
|
||||
self.norm3 = None
|
||||
@@ -793,319 +782,6 @@ class SkipFFTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class FreeNoiseTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A FreeNoise Transformer block.
|
||||
|
||||
Parameters:
|
||||
dim (`int`):
|
||||
The number of channels in the input and output.
|
||||
num_attention_heads (`int`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`):
|
||||
The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*):
|
||||
The size of the encoder_hidden_states vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
||||
Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (`int`, *optional*):
|
||||
The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (`bool`, defaults to `False`):
|
||||
Configure if the attentions should contain a bias parameter.
|
||||
only_cross_attention (`bool`, defaults to `False`):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, defaults to `False`):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_type (`str`, defaults to `"layer_norm"`):
|
||||
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
||||
final_dropout (`bool` defaults to `False`):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
attention_type (`str`, defaults to `"default"`):
|
||||
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
||||
positional_embeddings (`str`, *optional*):
|
||||
The type of positional embeddings to apply to.
|
||||
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
||||
The maximum number of positional embeddings to apply.
|
||||
ff_inner_dim (`int`, *optional*):
|
||||
Hidden dimension of feed-forward MLP.
|
||||
ff_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in feed-forward MLP.
|
||||
attention_out_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in attention output project layer.
|
||||
context_length (`int`, defaults to `16`):
|
||||
The maximum number of frames that the FreeNoise block processes at once.
|
||||
context_stride (`int`, defaults to `4`):
|
||||
The number of frames to be skipped before starting to process a new batch of `context_length` frames.
|
||||
weighting_scheme (`str`, defaults to `"pyramid"`):
|
||||
The weighting scheme to use for weighting averaging of processed latent frames. As described in the
|
||||
Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
|
||||
used.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout: float = 0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm",
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = False,
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
context_length: int = 16,
|
||||
context_stride: int = 4,
|
||||
weighting_scheme: str = "pyramid",
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.dropout = dropout
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.double_self_attention = double_self_attention
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.positional_embeddings = positional_embeddings
|
||||
self.num_positional_embeddings = num_positional_embeddings
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
|
||||
|
||||
# We keep these boolean flags for backward-compatibility.
|
||||
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
||||
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
||||
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
||||
self.use_layer_norm = norm_type == "layer_norm"
|
||||
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
||||
|
||||
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
||||
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
||||
)
|
||||
|
||||
self.norm_type = norm_type
|
||||
self.num_embeds_ada_norm = num_embeds_ada_norm
|
||||
|
||||
if positional_embeddings and (num_positional_embeddings is None):
|
||||
raise ValueError(
|
||||
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
||||
)
|
||||
|
||||
if positional_embeddings == "sinusoidal":
|
||||
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
# 2. Cross-Attn
|
||||
if cross_attention_dim is not None or double_self_attention:
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
|
||||
# 3. Feed-forward
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
|
||||
frame_indices = []
|
||||
for i in range(0, num_frames - self.context_length + 1, self.context_stride):
|
||||
window_start = i
|
||||
window_end = min(num_frames, i + self.context_length)
|
||||
frame_indices.append((window_start, window_end))
|
||||
return frame_indices
|
||||
|
||||
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
|
||||
if weighting_scheme == "pyramid":
|
||||
if num_frames % 2 == 0:
|
||||
# num_frames = 4 => [1, 2, 2, 1]
|
||||
weights = list(range(1, num_frames // 2 + 1))
|
||||
weights = weights + weights[::-1]
|
||||
else:
|
||||
# num_frames = 5 => [1, 2, 3, 2, 1]
|
||||
weights = list(range(1, num_frames // 2 + 1))
|
||||
weights = weights + [num_frames // 2 + 1] + weights[::-1]
|
||||
else:
|
||||
raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
|
||||
|
||||
return weights
|
||||
|
||||
def set_free_noise_properties(
|
||||
self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
|
||||
) -> None:
|
||||
self.context_length = context_length
|
||||
self.context_stride = context_stride
|
||||
self.weighting_scheme = weighting_scheme
|
||||
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
||||
|
||||
# hidden_states: [B x H x W, F, C]
|
||||
device = hidden_states.device
|
||||
dtype = hidden_states.dtype
|
||||
|
||||
num_frames = hidden_states.size(1)
|
||||
frame_indices = self._get_frame_indices(num_frames)
|
||||
frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
|
||||
frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
|
||||
is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
|
||||
|
||||
# Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
|
||||
# For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
|
||||
# [(0, 16), (4, 20), (8, 24), (10, 26)]
|
||||
if not is_last_frame_batch_complete:
|
||||
if num_frames < self.context_length:
|
||||
raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
|
||||
last_frame_batch_length = num_frames - frame_indices[-1][1]
|
||||
frame_indices.append((num_frames - self.context_length, num_frames))
|
||||
|
||||
num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
|
||||
accumulated_values = torch.zeros_like(hidden_states)
|
||||
|
||||
for i, (frame_start, frame_end) in enumerate(frame_indices):
|
||||
# The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle
|
||||
# cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or
|
||||
# essentially a non-multiple of `context_length`.
|
||||
weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
|
||||
weights *= frame_weights
|
||||
|
||||
hidden_states_chunk = hidden_states[:, frame_start:frame_end]
|
||||
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 1. Self-Attention
|
||||
norm_hidden_states = self.norm1(hidden_states_chunk)
|
||||
|
||||
if self.pos_embed is not None:
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states_chunk = attn_output + hidden_states_chunk
|
||||
if hidden_states_chunk.ndim == 4:
|
||||
hidden_states_chunk = hidden_states_chunk.squeeze(1)
|
||||
|
||||
# 2. Cross-Attention
|
||||
if self.attn2 is not None:
|
||||
norm_hidden_states = self.norm2(hidden_states_chunk)
|
||||
|
||||
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states_chunk = attn_output + hidden_states_chunk
|
||||
|
||||
if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
|
||||
accumulated_values[:, -last_frame_batch_length:] += (
|
||||
hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
|
||||
)
|
||||
num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length]
|
||||
else:
|
||||
accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
|
||||
num_times_accumulated[:, frame_start:frame_end] += weights
|
||||
|
||||
hidden_states = torch.where(
|
||||
num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
|
||||
).to(dtype)
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self._chunk_size is not None:
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
r"""
|
||||
A feed-forward layer.
|
||||
|
||||
@@ -227,7 +227,6 @@ class Attention(nn.Module):
|
||||
self.to_k = None
|
||||
self.to_v = None
|
||||
|
||||
self.added_proj_bias = added_proj_bias
|
||||
if self.added_kv_proj_dim is not None:
|
||||
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
|
||||
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
|
||||
@@ -699,15 +698,12 @@ class Attention(nn.Module):
|
||||
in_features = concatenated_weights.shape[1]
|
||||
out_features = concatenated_weights.shape[0]
|
||||
|
||||
self.to_added_qkv = nn.Linear(
|
||||
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
|
||||
)
|
||||
self.to_added_qkv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype)
|
||||
self.to_added_qkv.weight.copy_(concatenated_weights)
|
||||
if self.added_proj_bias:
|
||||
concatenated_bias = torch.cat(
|
||||
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
|
||||
)
|
||||
self.to_added_qkv.bias.copy_(concatenated_bias)
|
||||
concatenated_bias = torch.cat(
|
||||
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
|
||||
)
|
||||
self.to_added_qkv.bias.copy_(concatenated_bias)
|
||||
|
||||
self.fused_projections = fuse
|
||||
|
||||
@@ -1106,326 +1102,6 @@ class JointAttnProcessor2_0:
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class PAGJointAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
context_input_ndim = encoder_hidden_states.ndim
|
||||
if context_input_ndim == 4:
|
||||
batch_size, channel, height, width = encoder_hidden_states.shape
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
# store the length of image patch sequences to create a mask that prevents interaction between patches
|
||||
# similar to making the self-attention map an identity matrix
|
||||
identity_block_size = hidden_states.shape[1]
|
||||
|
||||
# chunk
|
||||
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
|
||||
encoder_hidden_states_org, encoder_hidden_states_ptb = encoder_hidden_states.chunk(2)
|
||||
|
||||
################## original path ##################
|
||||
batch_size = encoder_hidden_states_org.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query_org = attn.to_q(hidden_states_org)
|
||||
key_org = attn.to_k(hidden_states_org)
|
||||
value_org = attn.to_v(hidden_states_org)
|
||||
|
||||
# `context` projections.
|
||||
encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
|
||||
encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
|
||||
encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
|
||||
|
||||
# attention
|
||||
query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
|
||||
key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
|
||||
value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
|
||||
|
||||
inner_dim = key_org.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states_org = F.scaled_dot_product_attention(
|
||||
query_org, key_org, value_org, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states_org = hidden_states_org.to(query_org.dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
hidden_states_org, encoder_hidden_states_org = (
|
||||
hidden_states_org[:, : residual.shape[1]],
|
||||
hidden_states_org[:, residual.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states_org = attn.to_out[0](hidden_states_org)
|
||||
# dropout
|
||||
hidden_states_org = attn.to_out[1](hidden_states_org)
|
||||
if not attn.context_pre_only:
|
||||
encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if context_input_ndim == 4:
|
||||
encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
|
||||
batch_size, channel, height, width
|
||||
)
|
||||
|
||||
################## perturbed path ##################
|
||||
|
||||
batch_size = encoder_hidden_states_ptb.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query_ptb = attn.to_q(hidden_states_ptb)
|
||||
key_ptb = attn.to_k(hidden_states_ptb)
|
||||
value_ptb = attn.to_v(hidden_states_ptb)
|
||||
|
||||
# `context` projections.
|
||||
encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
|
||||
encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
|
||||
encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
|
||||
|
||||
# attention
|
||||
query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
|
||||
key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
|
||||
value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
|
||||
|
||||
inner_dim = key_ptb.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# create a full mask with all entries set to 0
|
||||
seq_len = query_ptb.size(2)
|
||||
full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
|
||||
|
||||
# set the attention value between image patches to -inf
|
||||
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
|
||||
|
||||
# set the diagonal of the attention value between image patches to 0
|
||||
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
|
||||
|
||||
# expand the mask to match the attention weights shape
|
||||
full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
|
||||
|
||||
hidden_states_ptb = F.scaled_dot_product_attention(
|
||||
query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
|
||||
|
||||
# split the attention outputs.
|
||||
hidden_states_ptb, encoder_hidden_states_ptb = (
|
||||
hidden_states_ptb[:, : residual.shape[1]],
|
||||
hidden_states_ptb[:, residual.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
|
||||
# dropout
|
||||
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
|
||||
if not attn.context_pre_only:
|
||||
encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if context_input_ndim == 4:
|
||||
encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
|
||||
batch_size, channel, height, width
|
||||
)
|
||||
|
||||
################ concat ###############
|
||||
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class PAGCFGJointAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
context_input_ndim = encoder_hidden_states.ndim
|
||||
if context_input_ndim == 4:
|
||||
batch_size, channel, height, width = encoder_hidden_states.shape
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
identity_block_size = hidden_states.shape[
|
||||
1
|
||||
] # patch embeddings width * height (correspond to self-attention map width or height)
|
||||
|
||||
# chunk
|
||||
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
|
||||
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
|
||||
|
||||
(
|
||||
encoder_hidden_states_uncond,
|
||||
encoder_hidden_states_org,
|
||||
encoder_hidden_states_ptb,
|
||||
) = encoder_hidden_states.chunk(3)
|
||||
encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org])
|
||||
|
||||
################## original path ##################
|
||||
batch_size = encoder_hidden_states_org.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query_org = attn.to_q(hidden_states_org)
|
||||
key_org = attn.to_k(hidden_states_org)
|
||||
value_org = attn.to_v(hidden_states_org)
|
||||
|
||||
# `context` projections.
|
||||
encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
|
||||
encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
|
||||
encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
|
||||
|
||||
# attention
|
||||
query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
|
||||
key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
|
||||
value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
|
||||
|
||||
inner_dim = key_org.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states_org = F.scaled_dot_product_attention(
|
||||
query_org, key_org, value_org, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states_org = hidden_states_org.to(query_org.dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
hidden_states_org, encoder_hidden_states_org = (
|
||||
hidden_states_org[:, : residual.shape[1]],
|
||||
hidden_states_org[:, residual.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states_org = attn.to_out[0](hidden_states_org)
|
||||
# dropout
|
||||
hidden_states_org = attn.to_out[1](hidden_states_org)
|
||||
if not attn.context_pre_only:
|
||||
encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if context_input_ndim == 4:
|
||||
encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
|
||||
batch_size, channel, height, width
|
||||
)
|
||||
|
||||
################## perturbed path ##################
|
||||
|
||||
batch_size = encoder_hidden_states_ptb.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query_ptb = attn.to_q(hidden_states_ptb)
|
||||
key_ptb = attn.to_k(hidden_states_ptb)
|
||||
value_ptb = attn.to_v(hidden_states_ptb)
|
||||
|
||||
# `context` projections.
|
||||
encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
|
||||
encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
|
||||
encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
|
||||
|
||||
# attention
|
||||
query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
|
||||
key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
|
||||
value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
|
||||
|
||||
inner_dim = key_ptb.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# create a full mask with all entries set to 0
|
||||
seq_len = query_ptb.size(2)
|
||||
full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
|
||||
|
||||
# set the attention value between image patches to -inf
|
||||
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
|
||||
|
||||
# set the diagonal of the attention value between image patches to 0
|
||||
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
|
||||
|
||||
# expand the mask to match the attention weights shape
|
||||
full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
|
||||
|
||||
hidden_states_ptb = F.scaled_dot_product_attention(
|
||||
query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
|
||||
|
||||
# split the attention outputs.
|
||||
hidden_states_ptb, encoder_hidden_states_ptb = (
|
||||
hidden_states_ptb[:, : residual.shape[1]],
|
||||
hidden_states_ptb[:, residual.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
|
||||
# dropout
|
||||
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
|
||||
if not attn.context_pre_only:
|
||||
encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if context_input_ndim == 4:
|
||||
encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
|
||||
batch_size, channel, height, width
|
||||
)
|
||||
|
||||
################ concat ###############
|
||||
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class FusedJointAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
@@ -1598,103 +1274,6 @@ class AuraFlowAttnProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedAuraFlowAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing Aura Flow with fused projections."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
|
||||
raise ImportError(
|
||||
"FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
qkv = attn.to_qkv(hidden_states)
|
||||
split_size = qkv.shape[-1] // 3
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
# `context` projections.
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
|
||||
split_size = encoder_qkv.shape[-1] // 3
|
||||
(
|
||||
encoder_hidden_states_query_proj,
|
||||
encoder_hidden_states_key_proj,
|
||||
encoder_hidden_states_value_proj,
|
||||
) = torch.split(encoder_qkv, split_size, dim=-1)
|
||||
|
||||
# Reshape.
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim)
|
||||
|
||||
# Apply QK norm.
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Concatenate the projections.
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
)
|
||||
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
|
||||
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
|
||||
|
||||
query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
|
||||
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
||||
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
||||
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
# Attention.
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
if encoder_hidden_states is not None:
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
# YiYi to-do: refactor rope related functions/classes
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
|
||||
@@ -1,18 +1,3 @@
|
||||
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -22,7 +7,6 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..downsampling import CogVideoXDownsample3D
|
||||
@@ -32,11 +16,8 @@ from ..upsampling import CogVideoXUpsample3D
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class CogVideoXSafeConv3d(nn.Conv3d):
|
||||
r"""
|
||||
"""
|
||||
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
|
||||
"""
|
||||
|
||||
@@ -68,12 +49,12 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): Number of channels in the input tensor.
|
||||
out_channels (`int`): Number of output channels produced by the convolution.
|
||||
kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
|
||||
stride (`int`, defaults to `1`): Stride of the convolution.
|
||||
dilation (`int`, defaults to `1`): Dilation rate of the convolution.
|
||||
pad_mode (`str`, defaults to `"constant"`): Padding mode.
|
||||
in_channels (int): Number of channels in the input tensor.
|
||||
out_channels (int): Number of output channels.
|
||||
kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolutional kernel.
|
||||
stride (int, optional): Stride of the convolution. Default is 1.
|
||||
dilation (int, optional): Dilation rate of the convolution. Default is 1.
|
||||
pad_mode (str, optional): Padding mode. Default is "constant".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -117,31 +98,35 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
|
||||
self.conv_cache = None
|
||||
|
||||
def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
def fake_cp_pass_from_previous_rank(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
dim = self.temporal_dim
|
||||
kernel_size = self.time_kernel_size
|
||||
if kernel_size > 1:
|
||||
cached_inputs = (
|
||||
[self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
|
||||
)
|
||||
inputs = torch.cat(cached_inputs + [inputs], dim=2)
|
||||
if kernel_size == 1:
|
||||
return inputs
|
||||
|
||||
inputs = inputs.transpose(0, dim)
|
||||
|
||||
if self.conv_cache is not None:
|
||||
inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0)
|
||||
else:
|
||||
inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0)
|
||||
|
||||
inputs = inputs.transpose(0, dim).contiguous()
|
||||
return inputs
|
||||
|
||||
def _clear_fake_context_parallel_cache(self):
|
||||
def forward(self, inputs: torch.Tensor, clear_fake_cp_cache: bool = True):
|
||||
input_parallel = self.fake_cp_pass_from_previous_rank(inputs)
|
||||
|
||||
del self.conv_cache
|
||||
self.conv_cache = None
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = self.fake_context_parallel_forward(inputs)
|
||||
|
||||
self._clear_fake_context_parallel_cache()
|
||||
# Note: we could move these to the cpu for a lower maximum memory usage but its only a few
|
||||
# hundred megabytes and so let's not do it for now
|
||||
self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
|
||||
if not clear_fake_cp_cache:
|
||||
self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
|
||||
|
||||
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
|
||||
input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
|
||||
|
||||
output = self.conv(inputs)
|
||||
output_parallel = self.conv(input_parallel)
|
||||
output = output_parallel
|
||||
return output
|
||||
|
||||
|
||||
@@ -157,18 +142,15 @@ class CogVideoXSpatialNorm3D(nn.Module):
|
||||
The number of channels for input to group normalization layer, and output of the spatial norm layer.
|
||||
zq_channels (`int`):
|
||||
The number of channels for the quantized vector as described in the paper.
|
||||
groups (`int`):
|
||||
Number of groups to separate the channels into for group normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
f_channels: int,
|
||||
zq_channels: int,
|
||||
groups: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
|
||||
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
|
||||
self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
|
||||
@@ -193,26 +175,17 @@ class CogVideoXResnetBlock3D(nn.Module):
|
||||
A 3D ResNet block used in the CogVideoX model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Number of input channels.
|
||||
out_channels (`int`, *optional*):
|
||||
Number of output channels. If None, defaults to `in_channels`.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
Dropout rate.
|
||||
temb_channels (`int`, defaults to `512`):
|
||||
Number of time embedding channels.
|
||||
groups (`int`, defaults to `32`):
|
||||
Number of groups to separate the channels into for group normalization.
|
||||
eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
non_linearity (`str`, defaults to `"swish"`):
|
||||
Activation function to use.
|
||||
conv_shortcut (bool, defaults to `False`):
|
||||
Whether or not to use a convolution shortcut.
|
||||
spatial_norm_dim (`int`, *optional*):
|
||||
The dimension to use for spatial norm if it is to be used instead of group norm.
|
||||
pad_mode (str, defaults to `"first"`):
|
||||
Padding mode.
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (Optional[int], optional):
|
||||
Number of output channels. If None, defaults to `in_channels`. Default is None.
|
||||
dropout (float, optional): Dropout rate. Default is 0.0.
|
||||
temb_channels (int, optional): Number of time embedding channels. Default is 512.
|
||||
groups (int, optional): Number of groups for group normalization. Default is 32.
|
||||
eps (float, optional): Epsilon value for normalization layers. Default is 1e-6.
|
||||
non_linearity (str, optional): Activation function to use. Default is "swish".
|
||||
conv_shortcut (bool, optional): If True, use a convolutional shortcut. Default is False.
|
||||
spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
|
||||
pad_mode (str, optional): Padding mode. Default is "first".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -244,12 +217,10 @@ class CogVideoXResnetBlock3D(nn.Module):
|
||||
self.norm1 = CogVideoXSpatialNorm3D(
|
||||
f_channels=in_channels,
|
||||
zq_channels=spatial_norm_dim,
|
||||
groups=groups,
|
||||
)
|
||||
self.norm2 = CogVideoXSpatialNorm3D(
|
||||
f_channels=out_channels,
|
||||
zq_channels=spatial_norm_dim,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
self.conv1 = CogVideoXCausalConv3d(
|
||||
@@ -279,16 +250,15 @@ class CogVideoXResnetBlock3D(nn.Module):
|
||||
inputs: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
clear_fake_cp_cache: bool = True,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = inputs
|
||||
|
||||
if zq is not None:
|
||||
hidden_states = self.norm1(hidden_states, zq)
|
||||
else:
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
|
||||
if temb is not None:
|
||||
hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
|
||||
@@ -300,13 +270,16 @@ class CogVideoXResnetBlock3D(nn.Module):
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
inputs = self.conv_shortcut(inputs)
|
||||
if self.use_conv_shortcut:
|
||||
inputs = self.conv_shortcut(inputs, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
else:
|
||||
inputs = self.conv_shortcut(inputs)
|
||||
|
||||
hidden_states = hidden_states + inputs
|
||||
return hidden_states
|
||||
output_tensor = inputs + hidden_states
|
||||
return output_tensor
|
||||
|
||||
|
||||
class CogVideoXDownBlock3D(nn.Module):
|
||||
@@ -314,28 +287,18 @@ class CogVideoXDownBlock3D(nn.Module):
|
||||
A downsampling block used in the CogVideoX model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Number of input channels.
|
||||
out_channels (`int`, *optional*):
|
||||
Number of output channels. If None, defaults to `in_channels`.
|
||||
temb_channels (`int`, defaults to `512`):
|
||||
Number of time embedding channels.
|
||||
num_layers (`int`, defaults to `1`):
|
||||
Number of resnet layers.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
Dropout rate.
|
||||
resnet_eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
resnet_act_fn (`str`, defaults to `"swish"`):
|
||||
Activation function to use.
|
||||
resnet_groups (`int`, defaults to `32`):
|
||||
Number of groups to separate the channels into for group normalization.
|
||||
add_downsample (`bool`, defaults to `True`):
|
||||
Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
|
||||
compress_time (`bool`, defaults to `False`):
|
||||
Whether or not to downsample across temporal dimension.
|
||||
pad_mode (str, defaults to `"first"`):
|
||||
Padding mode.
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
temb_channels (int): Number of time embedding channels.
|
||||
dropout (float, optional): Dropout rate. Default is 0.0.
|
||||
num_layers (int, optional): Number of layers in the block. Default is 1.
|
||||
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
|
||||
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
|
||||
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
|
||||
add_downsample (bool, optional): If True, add a downsampling layer at the end of the block. Default is True.
|
||||
downsample_padding (int, optional): Padding for the downsampling layer. Default is 0.
|
||||
compress_time (bool, optional): If True, apply temporal compression. Default is False.
|
||||
pad_mode (str, optional): Padding mode. Default is "first".
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -392,6 +355,7 @@ class CogVideoXDownBlock3D(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
clear_fake_cp_cache: bool = False,
|
||||
) -> torch.Tensor:
|
||||
for resnet in self.resnets:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
@@ -403,10 +367,10 @@ class CogVideoXDownBlock3D(nn.Module):
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, zq
|
||||
create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, zq)
|
||||
hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
@@ -420,24 +384,15 @@ class CogVideoXMidBlock3D(nn.Module):
|
||||
A middle block used in the CogVideoX model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Number of input channels.
|
||||
temb_channels (`int`, defaults to `512`):
|
||||
Number of time embedding channels.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
Dropout rate.
|
||||
num_layers (`int`, defaults to `1`):
|
||||
Number of resnet layers.
|
||||
resnet_eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
resnet_act_fn (`str`, defaults to `"swish"`):
|
||||
Activation function to use.
|
||||
resnet_groups (`int`, defaults to `32`):
|
||||
Number of groups to separate the channels into for group normalization.
|
||||
spatial_norm_dim (`int`, *optional*):
|
||||
The dimension to use for spatial norm if it is to be used instead of group norm.
|
||||
pad_mode (str, defaults to `"first"`):
|
||||
Padding mode.
|
||||
in_channels (int): Number of input channels.
|
||||
temb_channels (int): Number of time embedding channels.
|
||||
dropout (float, optional): Dropout rate. Default is 0.0.
|
||||
num_layers (int, optional): Number of layers in the block. Default is 1.
|
||||
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
|
||||
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
|
||||
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
|
||||
spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
|
||||
pad_mode (str, optional): Padding mode. Default is "first".
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -480,6 +435,7 @@ class CogVideoXMidBlock3D(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
clear_fake_cp_cache: bool = False,
|
||||
) -> torch.Tensor:
|
||||
for resnet in self.resnets:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
@@ -491,10 +447,10 @@ class CogVideoXMidBlock3D(nn.Module):
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, zq
|
||||
create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, zq)
|
||||
hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -504,30 +460,19 @@ class CogVideoXUpBlock3D(nn.Module):
|
||||
An upsampling block used in the CogVideoX model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Number of input channels.
|
||||
out_channels (`int`, *optional*):
|
||||
Number of output channels. If None, defaults to `in_channels`.
|
||||
temb_channels (`int`, defaults to `512`):
|
||||
Number of time embedding channels.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
Dropout rate.
|
||||
num_layers (`int`, defaults to `1`):
|
||||
Number of resnet layers.
|
||||
resnet_eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
resnet_act_fn (`str`, defaults to `"swish"`):
|
||||
Activation function to use.
|
||||
resnet_groups (`int`, defaults to `32`):
|
||||
Number of groups to separate the channels into for group normalization.
|
||||
spatial_norm_dim (`int`, defaults to `16`):
|
||||
The dimension to use for spatial norm if it is to be used instead of group norm.
|
||||
add_upsample (`bool`, defaults to `True`):
|
||||
Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension.
|
||||
compress_time (`bool`, defaults to `False`):
|
||||
Whether or not to downsample across temporal dimension.
|
||||
pad_mode (str, defaults to `"first"`):
|
||||
Padding mode.
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
temb_channels (int): Number of time embedding channels.
|
||||
dropout (float, optional): Dropout rate. Default is 0.0.
|
||||
num_layers (int, optional): Number of layers in the block. Default is 1.
|
||||
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
|
||||
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
|
||||
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
|
||||
spatial_norm_dim (int, optional): Dimension of the spatial normalization. Default is 16.
|
||||
add_upsample (bool, optional): If True, add an upsampling layer at the end of the block. Default is True.
|
||||
upsample_padding (int, optional): Padding for the upsampling layer. Default is 1.
|
||||
compress_time (bool, optional): If True, apply temporal compression. Default is False.
|
||||
pad_mode (str, optional): Padding mode. Default is "first".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -577,13 +522,12 @@ class CogVideoXUpBlock3D(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
clear_fake_cp_cache: bool = False,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `CogVideoXUpBlock3D` class."""
|
||||
for resnet in self.resnets:
|
||||
@@ -596,10 +540,10 @@ class CogVideoXUpBlock3D(nn.Module):
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, zq
|
||||
create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, zq)
|
||||
hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -622,12 +566,14 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
options.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
The number of output channels for each block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
The number of layers per block.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups for normalization.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
double_z (`bool`, *optional*, defaults to `True`):
|
||||
Whether to double the number of output channels for the last block.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -705,9 +651,11 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def forward(
|
||||
self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True
|
||||
) -> torch.Tensor:
|
||||
r"""The forward method of the `CogVideoXEncoder3D` class."""
|
||||
hidden_states = self.conv_in(sample)
|
||||
hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -720,25 +668,25 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
# 1. Down
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(down_block), hidden_states, temb, None
|
||||
create_custom_forward(down_block), hidden_states, temb, None, clear_fake_cp_cache
|
||||
)
|
||||
|
||||
# 2. Mid
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), hidden_states, temb, None
|
||||
create_custom_forward(self.mid_block), hidden_states, temb, None, clear_fake_cp_cache
|
||||
)
|
||||
else:
|
||||
# 1. Down
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states, temb, None)
|
||||
hidden_states = down_block(hidden_states, temb, None, clear_fake_cp_cache)
|
||||
|
||||
# 2. Mid
|
||||
hidden_states = self.mid_block(hidden_states, temb, None)
|
||||
hidden_states = self.mid_block(hidden_states, temb, None, clear_fake_cp_cache)
|
||||
|
||||
# 3. Post-process
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -756,12 +704,14 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
The number of output channels for each block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
The number of layers per block.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups for normalization.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
norm_type (`str`, *optional*, defaults to `"group"`):
|
||||
The normalization type to use. Can be either `"group"` or `"spatial"`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -838,7 +788,7 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
|
||||
self.up_blocks.append(up_block)
|
||||
|
||||
self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
|
||||
self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CogVideoXCausalConv3d(
|
||||
reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
|
||||
@@ -846,9 +796,11 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def forward(
|
||||
self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True
|
||||
) -> torch.Tensor:
|
||||
r"""The forward method of the `CogVideoXDecoder3D` class."""
|
||||
hidden_states = self.conv_in(sample)
|
||||
hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -860,33 +812,32 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
|
||||
# 1. Mid
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), hidden_states, temb, sample
|
||||
create_custom_forward(self.mid_block), hidden_states, temb, sample, clear_fake_cp_cache
|
||||
)
|
||||
|
||||
# 2. Up
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block), hidden_states, temb, sample
|
||||
create_custom_forward(up_block), hidden_states, temb, sample, clear_fake_cp_cache
|
||||
)
|
||||
else:
|
||||
# 1. Mid
|
||||
hidden_states = self.mid_block(hidden_states, temb, sample)
|
||||
hidden_states = self.mid_block(hidden_states, temb, sample, clear_fake_cp_cache)
|
||||
|
||||
# 2. Up
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = up_block(hidden_states, temb, sample)
|
||||
hidden_states = up_block(hidden_states, temb, sample, clear_fake_cp_cache)
|
||||
|
||||
# 3. Post-process
|
||||
hidden_states = self.norm_out(hidden_states, sample)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
|
||||
[CogVideoX](https://github.com/THUDM/CogVideo).
|
||||
A VAE model with KL loss for encodfing images into latents and decoding latent representations into images.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
@@ -913,6 +864,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
|
||||
can be fine-tuned / trained to a lower range without loosing too much precision in which case
|
||||
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
|
||||
mid_block_add_attention (`bool`, *optional*, default to `True`):
|
||||
If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
|
||||
mid_block will only have resnet blocks
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -942,8 +896,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
norm_eps: float = 1e-6,
|
||||
norm_num_groups: int = 32,
|
||||
temporal_compression_ratio: float = 4,
|
||||
sample_height: int = 480,
|
||||
sample_width: int = 720,
|
||||
sample_size: int = 256,
|
||||
scaling_factor: float = 1.15258426,
|
||||
shift_factor: Optional[float] = None,
|
||||
latents_mean: Optional[Tuple[float]] = None,
|
||||
@@ -951,6 +904,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
force_upcast: float = True,
|
||||
use_quant_conv: bool = False,
|
||||
use_post_quant_conv: bool = False,
|
||||
mid_block_add_attention: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -982,108 +936,22 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
# Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
|
||||
# recommended because the temporal parts of the VAE, here, are tricky to understand.
|
||||
# If you decode X latent frames together, the number of output frames is:
|
||||
# (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
|
||||
#
|
||||
# Example with num_latent_frames_batch_size = 2:
|
||||
# - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
|
||||
# => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
|
||||
# => 6 * 8 = 48 frames
|
||||
# - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
|
||||
# => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
|
||||
# ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
|
||||
# => 1 * 9 + 5 * 8 = 49 frames
|
||||
# It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
|
||||
# setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
|
||||
# number of temporal frames.
|
||||
self.num_latent_frames_batch_size = 2
|
||||
|
||||
# We make the minimum height and width of sample for tiling half that of the generally supported
|
||||
self.tile_sample_min_height = sample_height // 2
|
||||
self.tile_sample_min_width = sample_width // 2
|
||||
self.tile_latent_min_height = int(
|
||||
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
|
||||
self.tile_sample_min_size = self.config.sample_size
|
||||
sample_size = (
|
||||
self.config.sample_size[0]
|
||||
if isinstance(self.config.sample_size, (list, tuple))
|
||||
else self.config.sample_size
|
||||
)
|
||||
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
|
||||
# These are experimental overlap factors that were chosen based on experimentation and seem to work best for
|
||||
# 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
|
||||
# and so the tiling implementation has only been tested on those specific resolutions.
|
||||
self.tile_overlap_factor_height = 1 / 6
|
||||
self.tile_overlap_factor_width = 1 / 5
|
||||
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def _clear_fake_context_parallel_cache(self):
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, CogVideoXCausalConv3d):
|
||||
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
|
||||
module._clear_fake_context_parallel_cache()
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
tile_sample_min_width: Optional[int] = None,
|
||||
tile_overlap_factor_height: Optional[float] = None,
|
||||
tile_overlap_factor_width: Optional[float] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
|
||||
Args:
|
||||
tile_sample_min_height (`int`, *optional*):
|
||||
The minimum height required for a sample to be separated into tiles across the height dimension.
|
||||
tile_sample_min_width (`int`, *optional*):
|
||||
The minimum width required for a sample to be separated into tiles across the width dimension.
|
||||
tile_overlap_factor_height (`int`, *optional*):
|
||||
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
|
||||
no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
|
||||
value might cause more tiles to be processed leading to slow down of the decoding process.
|
||||
tile_overlap_factor_width (`int`, *optional*):
|
||||
The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
|
||||
are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
|
||||
value might cause more tiles to be processed leading to slow down of the decoding process.
|
||||
"""
|
||||
self.use_tiling = True
|
||||
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
||||
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
||||
self.tile_latent_min_height = int(
|
||||
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
|
||||
)
|
||||
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
|
||||
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
self, x: torch.Tensor, return_dict: bool = True, fake_cp: bool = False
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
@@ -1092,12 +960,14 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
fake_cp (`bool`, *optional*, defaults to `True`):
|
||||
If True, the fake context parallel will be used to reduce GPU memory consumption (Only 1 GPU work).
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded images. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
h = self.encoder(x)
|
||||
h = self.encoder(x, clear_fake_cp_cache=not fake_cp)
|
||||
if self.quant_conv is not None:
|
||||
h = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
@@ -1105,34 +975,10 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
frame_batch_size = self.num_latent_frames_batch_size
|
||||
dec = []
|
||||
for i in range(num_frames // frame_batch_size):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
||||
end_frame = frame_batch_size * (i + 1) + remaining_frames
|
||||
z_intermediate = z[:, :, start_frame:end_frame]
|
||||
if self.post_quant_conv is not None:
|
||||
z_intermediate = self.post_quant_conv(z_intermediate)
|
||||
z_intermediate = self.decoder(z_intermediate)
|
||||
dec.append(z_intermediate)
|
||||
|
||||
self._clear_fake_context_parallel_cache()
|
||||
dec = torch.cat(dec, dim=2)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(
|
||||
self, z: torch.FloatTensor, return_dict: bool = True, fake_cp: bool = False
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -1140,116 +986,20 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
fake_cp (`bool`, *optional*, defaults to `True`):
|
||||
If True, the fake context parallel will be used to reduce GPU memory consumption (Only 1 GPU work).
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
|
||||
y / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
|
||||
x / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
# Rough memory assessment:
|
||||
# - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
|
||||
# - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
|
||||
# - Assume fp16 (2 bytes per value).
|
||||
# Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
|
||||
#
|
||||
# Memory assessment when using tiling:
|
||||
# - Assume everything as above but now HxW is 240x360 by tiling in half
|
||||
# Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
|
||||
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
|
||||
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
|
||||
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
|
||||
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
|
||||
row_limit_height = self.tile_sample_min_height - blend_extent_height
|
||||
row_limit_width = self.tile_sample_min_width - blend_extent_width
|
||||
frame_batch_size = self.num_latent_frames_batch_size
|
||||
|
||||
# Split z into overlapping tiles and decode them separately.
|
||||
# The tiles have an overlap to avoid seams between tiles.
|
||||
rows = []
|
||||
for i in range(0, height, overlap_height):
|
||||
row = []
|
||||
for j in range(0, width, overlap_width):
|
||||
time = []
|
||||
for k in range(num_frames // frame_batch_size):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
|
||||
end_frame = frame_batch_size * (k + 1) + remaining_frames
|
||||
tile = z[
|
||||
:,
|
||||
:,
|
||||
start_frame:end_frame,
|
||||
i : i + self.tile_latent_min_height,
|
||||
j : j + self.tile_latent_min_width,
|
||||
]
|
||||
if self.post_quant_conv is not None:
|
||||
tile = self.post_quant_conv(tile)
|
||||
tile = self.decoder(tile)
|
||||
time.append(tile)
|
||||
self._clear_fake_context_parallel_cache()
|
||||
row.append(torch.cat(time, dim=2))
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
|
||||
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
|
||||
result_rows.append(torch.cat(result_row, dim=4))
|
||||
|
||||
dec = torch.cat(result_rows, dim=3)
|
||||
|
||||
if self.post_quant_conv is not None:
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z, clear_fake_cp_cache=not fake_cp)
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -1,374 +0,0 @@
|
||||
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import PeftAdapterMixin
|
||||
from ..models.attention_processor import AttentionProcessor
|
||||
from ..models.modeling_utils import ModelMixin
|
||||
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from .controlnet import BaseOutput, zero_module
|
||||
from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
|
||||
from .modeling_outputs import Transformer2DModelOutput
|
||||
from .transformers.transformer_flux import EmbedND, FluxSingleTransformerBlock, FluxTransformerBlock
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_single_block_samples: Tuple[torch.Tensor]
|
||||
|
||||
|
||||
class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 64,
|
||||
num_layers: int = 19,
|
||||
num_single_layers: int = 38,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
|
||||
text_time_guidance_cls = (
|
||||
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
|
||||
)
|
||||
self.time_text_embed = text_time_guidance_cls(
|
||||
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
||||
)
|
||||
|
||||
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
|
||||
self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxSingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for i in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# controlnet_blocks
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.transformer_blocks)):
|
||||
self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
|
||||
|
||||
self.controlnet_single_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.single_transformer_blocks)):
|
||||
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
|
||||
|
||||
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self):
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@classmethod
|
||||
def from_transformer(
|
||||
cls,
|
||||
transformer,
|
||||
num_layers=4,
|
||||
num_single_layers=10,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
load_weights_from_transformer=True,
|
||||
):
|
||||
config = transformer.config
|
||||
config["num_layers"] = num_layers
|
||||
config["num_single_layers"] = num_single_layers
|
||||
config["attention_head_dim"] = attention_head_dim
|
||||
config["num_attention_heads"] = num_attention_heads
|
||||
|
||||
controlnet = cls(**config)
|
||||
|
||||
if load_weights_from_transformer:
|
||||
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
|
||||
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
|
||||
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
|
||||
controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
|
||||
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
|
||||
controlnet.single_transformer_blocks.load_state_dict(
|
||||
transformer.single_transformer_blocks.state_dict(), strict=False
|
||||
)
|
||||
|
||||
controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
|
||||
|
||||
return controlnet
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
# add
|
||||
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype) * 1000
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
else:
|
||||
guidance = None
|
||||
temb = (
|
||||
self.time_text_embed(timestep, pooled_projections)
|
||||
if guidance is None
|
||||
else self.time_text_embed(timestep, guidance, pooled_projections)
|
||||
)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
block_samples = ()
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
block_samples = block_samples + (hidden_states,)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
single_block_samples = ()
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
|
||||
|
||||
# controlnet block
|
||||
controlnet_block_samples = ()
|
||||
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
|
||||
block_sample = controlnet_block(block_sample)
|
||||
controlnet_block_samples = controlnet_block_samples + (block_sample,)
|
||||
|
||||
controlnet_single_block_samples = ()
|
||||
for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
|
||||
single_block_sample = controlnet_block(single_block_sample)
|
||||
controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
|
||||
|
||||
# scaling
|
||||
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
|
||||
controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
|
||||
|
||||
#
|
||||
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
|
||||
controlnet_single_block_samples = (
|
||||
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
|
||||
)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (controlnet_block_samples, controlnet_single_block_samples)
|
||||
|
||||
return FluxControlNetOutput(
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
controlnet_single_block_samples=controlnet_single_block_samples,
|
||||
)
|
||||
@@ -55,7 +55,6 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
pooled_projection_dim: int = 2048,
|
||||
out_channels: int = 16,
|
||||
pos_embed_max_size: int = 96,
|
||||
extra_conditioning_channels: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
default_out_channels = in_channels
|
||||
@@ -99,7 +98,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels + extra_conditioning_channels,
|
||||
in_channels=in_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
pos_embed_type=None,
|
||||
)
|
||||
|
||||
@@ -20,7 +20,6 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalModelMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
@@ -93,7 +92,7 @@ class SparseControlNetConditioningEmbedding(nn.Module):
|
||||
return embedding
|
||||
|
||||
|
||||
class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
class SparseControlNetModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
|
||||
Models](https://arxiv.org/abs/2311.16933).
|
||||
@@ -315,7 +314,6 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
temporal_num_attention_heads=motion_num_attention_heads[i],
|
||||
temporal_max_seq_length=motion_max_seq_length,
|
||||
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
|
||||
temporal_double_self_attention=False,
|
||||
)
|
||||
elif down_block_type == "DownBlockMotion":
|
||||
down_block = DownBlockMotion(
|
||||
@@ -333,7 +331,6 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
temporal_num_attention_heads=motion_num_attention_heads[i],
|
||||
temporal_max_seq_length=motion_max_seq_length,
|
||||
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
|
||||
temporal_double_self_attention=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
||||
@@ -285,7 +285,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
upcast_attention (`bool`, defaults to `True`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
max_norm_num_groups (`int`, defaults to 32):
|
||||
Maximum number of groups in group normal. The actual number will be the largest divisor of the respective
|
||||
Maximum number of groups in group normal. The actual number will the the largest divisor of the respective
|
||||
channels, that is <= max_norm_num_groups.
|
||||
"""
|
||||
|
||||
|
||||
@@ -773,7 +773,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
try:
|
||||
accelerate.load_checkpoint_and_dispatch(
|
||||
model,
|
||||
model_file if not is_sharded else index_file,
|
||||
model_file if not is_sharded else sharded_ckpt_cached_folder,
|
||||
device_map,
|
||||
max_memory=max_memory,
|
||||
offload_folder=offload_folder,
|
||||
@@ -803,7 +803,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
model._temp_convert_self_to_deprecated_attention_blocks()
|
||||
accelerate.load_checkpoint_and_dispatch(
|
||||
model,
|
||||
model_file if not is_sharded else index_file,
|
||||
model_file if not is_sharded else sharded_ckpt_cached_folder,
|
||||
device_map,
|
||||
max_memory=max_memory,
|
||||
offload_folder=offload_folder,
|
||||
|
||||
@@ -34,11 +34,7 @@ class AdaLayerNorm(nn.Module):
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
|
||||
output_dim (`int`, *optional*):
|
||||
norm_elementwise_affine (`bool`, defaults to `False):
|
||||
norm_eps (`bool`, defaults to `False`):
|
||||
chunk_dim (`int`, defaults to `0`):
|
||||
num_embeddings (`int`): The size of the embeddings dictionary.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -53,13 +49,14 @@ class AdaLayerNorm(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.chunk_dim = chunk_dim
|
||||
output_dim = output_dim or embedding_dim * 2
|
||||
|
||||
if num_embeddings is not None:
|
||||
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
||||
else:
|
||||
self.emb = None
|
||||
|
||||
output_dim = output_dim or embedding_dim * 2
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, output_dim)
|
||||
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
||||
@@ -71,10 +68,7 @@ class AdaLayerNorm(nn.Module):
|
||||
temb = self.emb(timestep)
|
||||
|
||||
temb = self.linear(self.silu(temb))
|
||||
|
||||
if self.chunk_dim == 1:
|
||||
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
|
||||
# other if-branch. This branch is specific to CogVideoX for now.
|
||||
shift, scale = temb.chunk(2, dim=1)
|
||||
shift = shift[:, None, :]
|
||||
scale = scale[:, None, :]
|
||||
|
||||
@@ -22,12 +22,7 @@ import torch.nn.functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torch_version, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
AuraFlowAttnProcessor2_0,
|
||||
FusedAuraFlowAttnProcessor2_0,
|
||||
)
|
||||
from ..attention_processor import Attention, AuraFlowAttnProcessor2_0
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -325,106 +320,6 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedAuraFlowAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@@ -34,37 +34,37 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@maybe_allow_in_graph
|
||||
class CogVideoXBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
|
||||
Transformer block used in CogVideoX model. TODO: add link to CogVideoX upon release
|
||||
|
||||
Parameters:
|
||||
dim (`int`):
|
||||
The number of channels in the input and output.
|
||||
num_attention_heads (`int`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`):
|
||||
The number of channels in each head.
|
||||
time_embed_dim (`int`):
|
||||
The number of channels in timestep embedding.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to be used in feed-forward.
|
||||
attention_bias (`bool`, defaults to `False`):
|
||||
Whether or not to use bias in attention projection layers.
|
||||
qk_norm (`bool`, defaults to `True`):
|
||||
Whether or not to use normalization after query and key projections in Attention.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
only_cross_attention (`bool`, *optional*):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
upcast_attention (`bool`, *optional*):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_eps (`float`, defaults to `1e-5`):
|
||||
Epsilon value for normalization layers.
|
||||
final_dropout (`bool` defaults to `False`):
|
||||
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
||||
final_dropout (`bool` *optional*, defaults to False):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
ff_inner_dim (`int`, *optional*, defaults to `None`):
|
||||
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
|
||||
ff_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in Feed-forward layer.
|
||||
attention_out_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in Attention output projection layer.
|
||||
attention_type (`str`, *optional*, defaults to `"default"`):
|
||||
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
||||
positional_embeddings (`str`, *optional*, defaults to `None`):
|
||||
The type of positional embeddings to apply to.
|
||||
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
||||
The maximum number of positional embeddings to apply.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -151,56 +151,39 @@ class CogVideoXBlock(nn.Module):
|
||||
|
||||
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
||||
A Transformer model for video-like data in CogVideoX. TODO: add link to CogVideoX upon release
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, defaults to `30`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `64`):
|
||||
The number of channels in each head.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `16`):
|
||||
out_channels (`int`, *optional*):
|
||||
The number of channels in the output.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
time_embed_dim (`int`, defaults to `512`):
|
||||
Output dimension of timestep embeddings.
|
||||
text_embed_dim (`int`, defaults to `4096`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
num_layers (`int`, defaults to `30`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
attention_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in the attention projection layers.
|
||||
sample_width (`int`, defaults to `90`):
|
||||
The width of the input latents.
|
||||
sample_height (`int`, defaults to `60`):
|
||||
The height of the input latents.
|
||||
sample_frames (`int`, defaults to `49`):
|
||||
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
|
||||
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
|
||||
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
|
||||
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
|
||||
patch_size (`int`, defaults to `2`):
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
||||
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||
This is fixed during training since it is used to learn a number of position embeddings.
|
||||
patch_size (`int`, *optional*):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
temporal_compression_ratio (`int`, defaults to `4`):
|
||||
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
|
||||
max_text_seq_length (`int`, defaults to `226`):
|
||||
The maximum sequence length of the input text embeddings.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to use in feed-forward.
|
||||
timestep_activation_fn (`str`, defaults to `"silu"`):
|
||||
Activation function to use when generating the timestep embeddings.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*):
|
||||
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
||||
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
||||
added to the hidden states. During inference, you can denoise for up to but not more steps than
|
||||
`num_embeds_ada_norm`.
|
||||
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to use elementwise affine in normalization layers.
|
||||
norm_eps (`float`, defaults to `1e-5`):
|
||||
The epsilon value to use in normalization layers.
|
||||
spatial_interpolation_scale (`float`, defaults to `1.875`):
|
||||
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
|
||||
temporal_interpolation_scale (`float`, defaults to `1.0`):
|
||||
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
|
||||
caption_channels (`int`, *optional*):
|
||||
The number of channels in the caption embeddings.
|
||||
video_length (`int`, *optional*):
|
||||
The number of frames in the video-like data.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -210,7 +193,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
self,
|
||||
num_attention_heads: int = 30,
|
||||
attention_head_dim: int = 64,
|
||||
in_channels: int = 16,
|
||||
in_channels: Optional[int] = 16,
|
||||
out_channels: Optional[int] = 16,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
@@ -328,7 +311,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
|
||||
hidden_states = hidden_states[:, self.config.max_text_seq_length :]
|
||||
|
||||
# 4. Transformer blocks
|
||||
# 5. Transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -355,11 +338,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
|
||||
# 5. Final block
|
||||
# 6. Final block
|
||||
hidden_states = self.norm_out(hidden_states, temb=emb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 6. Unpatchify
|
||||
# 7. Unpatchify
|
||||
p = self.config.patch_size
|
||||
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
|
||||
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
||||
|
||||
@@ -19,7 +19,7 @@ from torch import nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torch_version, logging
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
|
||||
from ..attention_processor import AttentionProcessor
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -247,46 +247,6 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
# Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -15,13 +15,12 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
@@ -126,8 +125,6 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
gate = gate.unsqueeze(1)
|
||||
hidden_states = gate * self.proj_out(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -226,13 +223,11 @@ class FluxTransformerBlock(nn.Module):
|
||||
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
if encoder_hidden_states.dtype == torch.float16:
|
||||
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
"""
|
||||
The Transformer model introduced in Flux.
|
||||
|
||||
@@ -322,8 +317,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_block_samples=None,
|
||||
controlnet_single_block_samples=None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
@@ -380,7 +373,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
@@ -414,12 +406,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if controlnet_block_samples is not None:
|
||||
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
@@ -450,15 +436,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if controlnet_single_block_samples is not None:
|
||||
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
+ controlnet_single_block_samples[index_block // interval_control]
|
||||
)
|
||||
|
||||
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
|
||||
@@ -233,7 +233,6 @@ class DownBlockMotion(nn.Module):
|
||||
temporal_cross_attention_dim: Optional[int] = None,
|
||||
temporal_max_seq_length: int = 32,
|
||||
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
temporal_double_self_attention: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -283,7 +282,6 @@ class DownBlockMotion(nn.Module):
|
||||
positional_embeddings="sinusoidal",
|
||||
num_positional_embeddings=temporal_max_seq_length,
|
||||
attention_head_dim=out_channels // temporal_num_attention_heads[i],
|
||||
double_self_attention=temporal_double_self_attention,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -345,7 +343,6 @@ class DownBlockMotion(nn.Module):
|
||||
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
@@ -387,7 +384,6 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
temporal_num_attention_heads: int = 8,
|
||||
temporal_max_seq_length: int = 32,
|
||||
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
temporal_double_self_attention: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -469,7 +465,6 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
positional_embeddings="sinusoidal",
|
||||
num_positional_embeddings=temporal_max_seq_length,
|
||||
attention_head_dim=out_channels // temporal_num_attention_heads,
|
||||
double_self_attention=temporal_double_self_attention,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -541,7 +536,6 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -767,7 +761,6 @@ class CrossAttnUpBlockMotion(nn.Module):
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -928,9 +921,9 @@ class UpBlockMotion(nn.Module):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
@@ -1930,6 +1923,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
||||
"""
|
||||
Sets the attention processor to use [feed forward
|
||||
@@ -1959,6 +1953,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, chunk_size, dim)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
|
||||
def disable_forward_chunking(self) -> None:
|
||||
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||
if hasattr(module, "set_chunk_feed_forward"):
|
||||
|
||||
@@ -10,7 +10,6 @@ from ..utils import (
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_sentencepiece_available,
|
||||
is_torch_available,
|
||||
is_torch_npu_available,
|
||||
is_transformers_available,
|
||||
@@ -124,7 +123,7 @@ else:
|
||||
"AnimateDiffSparseControlNetPipeline",
|
||||
"AnimateDiffVideoToVideoPipeline",
|
||||
]
|
||||
_import_structure["flux"] = ["FluxPipeline", "FluxControlNetPipeline"]
|
||||
_import_structure["flux"] = ["FluxPipeline"]
|
||||
_import_structure["audioldm"] = ["AudioLDMPipeline"]
|
||||
_import_structure["audioldm2"] = [
|
||||
"AudioLDM2Pipeline",
|
||||
@@ -147,9 +146,7 @@ else:
|
||||
_import_structure["pag"].extend(
|
||||
[
|
||||
"AnimateDiffPAGPipeline",
|
||||
"KolorsPAGPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
"StableDiffusion3PAGPipeline",
|
||||
"StableDiffusionPAGPipeline",
|
||||
"StableDiffusionControlNetPAGPipeline",
|
||||
"StableDiffusionXLPAGPipeline",
|
||||
@@ -173,7 +170,6 @@ else:
|
||||
_import_structure["controlnet_sd3"].extend(
|
||||
[
|
||||
"StableDiffusion3ControlNetPipeline",
|
||||
"StableDiffusion3ControlNetInpaintingPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["deepfloyd_if"] = [
|
||||
@@ -210,6 +206,12 @@ else:
|
||||
"Kandinsky3Img2ImgPipeline",
|
||||
"Kandinsky3Pipeline",
|
||||
]
|
||||
_import_structure["kolors"] = [
|
||||
"KolorsPipeline",
|
||||
"KolorsImg2ImgPipeline",
|
||||
"ChatGLMModel",
|
||||
"ChatGLMTokenizer",
|
||||
]
|
||||
_import_structure["latent_consistency_models"] = [
|
||||
"LatentConsistencyModelImg2ImgPipeline",
|
||||
"LatentConsistencyModelPipeline",
|
||||
@@ -349,22 +351,6 @@ else:
|
||||
"StableDiffusionKDiffusionPipeline",
|
||||
"StableDiffusionXLKDiffusionPipeline",
|
||||
]
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import (
|
||||
dummy_torch_and_transformers_and_sentencepiece_objects,
|
||||
)
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_sentencepiece_objects))
|
||||
else:
|
||||
_import_structure["kolors"] = [
|
||||
"KolorsPipeline",
|
||||
"KolorsImg2ImgPipeline",
|
||||
]
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -466,7 +452,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .controlnet_hunyuandit import (
|
||||
HunyuanDiTControlNetPipeline,
|
||||
)
|
||||
from .controlnet_sd3 import StableDiffusion3ControlNetInpaintingPipeline, StableDiffusion3ControlNetPipeline
|
||||
from .controlnet_sd3 import (
|
||||
StableDiffusion3ControlNetPipeline,
|
||||
)
|
||||
from .controlnet_xs import (
|
||||
StableDiffusionControlNetXSPipeline,
|
||||
StableDiffusionXLControlNetXSPipeline,
|
||||
@@ -493,7 +481,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
VQDiffusionPipeline,
|
||||
)
|
||||
from .flux import FluxControlNetPipeline, FluxPipeline
|
||||
from .flux import FluxPipeline
|
||||
from .hunyuandit import HunyuanDiTPipeline
|
||||
from .i2vgen_xl import I2VGenXLPipeline
|
||||
from .kandinsky import (
|
||||
@@ -521,6 +509,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Kandinsky3Img2ImgPipeline,
|
||||
Kandinsky3Pipeline,
|
||||
)
|
||||
from .kolors import (
|
||||
ChatGLMModel,
|
||||
ChatGLMTokenizer,
|
||||
KolorsImg2ImgPipeline,
|
||||
KolorsPipeline,
|
||||
)
|
||||
from .latent_consistency_models import (
|
||||
LatentConsistencyModelImg2ImgPipeline,
|
||||
LatentConsistencyModelPipeline,
|
||||
@@ -542,9 +536,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pag import (
|
||||
AnimateDiffPAGPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
KolorsPAGPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionPAGPipeline,
|
||||
StableDiffusionXLControlNetPAGPipeline,
|
||||
@@ -652,17 +644,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLKDiffusionPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_sentencepiece_objects import *
|
||||
else:
|
||||
from .kolors import (
|
||||
KolorsImg2ImgPipeline,
|
||||
KolorsPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -42,7 +42,6 @@ from ...utils import (
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..free_init_utils import FreeInitMixin
|
||||
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
@@ -73,7 +72,6 @@ class AnimateDiffPipeline(
|
||||
IPAdapterMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
@@ -396,20 +394,15 @@ class AnimateDiffPipeline(
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
def decode_latents(self, latents, decode_chunk_size: int = 16):
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
batch_size, channels, num_frames, height, width = latents.shape
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
video = []
|
||||
for i in range(0, latents.shape[0], decode_chunk_size):
|
||||
batch_latents = latents[i : i + decode_chunk_size]
|
||||
batch_latents = self.vae.decode(batch_latents).sample
|
||||
video.append(batch_latents)
|
||||
|
||||
video = torch.cat(video)
|
||||
video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
image = self.vae.decode(latents).sample
|
||||
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
video = video.float()
|
||||
return video
|
||||
@@ -502,21 +495,10 @@ class AnimateDiffPipeline(
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
# If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
|
||||
if self.free_noise_enabled:
|
||||
latents = self._prepare_latents_free_noise(
|
||||
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
@@ -524,6 +506,11 @@ class AnimateDiffPipeline(
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
@@ -582,7 +569,6 @@ class AnimateDiffPipeline(
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
decode_chunk_size: int = 16,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -651,8 +637,6 @@ class AnimateDiffPipeline(
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
decode_chunk_size (`int`, defaults to `16`):
|
||||
The number of frames to decode at a time when calling `decode_latents` method.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -824,7 +808,7 @@ class AnimateDiffPipeline(
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents, decode_chunk_size)
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
|
||||
|
||||
# 10. Offload all models
|
||||
|
||||
@@ -30,7 +30,6 @@ from ...utils.torch_utils import is_compiled_module, randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..controlnet.multicontrolnet import MultiControlNetModel
|
||||
from ..free_init_utils import FreeInitMixin
|
||||
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
@@ -110,7 +109,6 @@ class AnimateDiffControlNetPipeline(
|
||||
IPAdapterMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-video generation with ControlNet guidance.
|
||||
@@ -434,16 +432,15 @@ class AnimateDiffControlNetPipeline(
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
|
||||
def decode_latents(self, latents, decode_chunk_size: int = 16):
|
||||
def decode_latents(self, latents, decode_batch_size: int = 16):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
batch_size, channels, num_frames, height, width = latents.shape
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
video = []
|
||||
for i in range(0, latents.shape[0], decode_chunk_size):
|
||||
batch_latents = latents[i : i + decode_chunk_size]
|
||||
for i in range(0, latents.shape[0], decode_batch_size):
|
||||
batch_latents = latents[i : i + decode_batch_size]
|
||||
batch_latents = self.vae.decode(batch_latents).sample
|
||||
video.append(batch_latents)
|
||||
|
||||
@@ -611,22 +608,10 @@ class AnimateDiffControlNetPipeline(
|
||||
if end > 1.0:
|
||||
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
|
||||
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
# If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
|
||||
if self.free_noise_enabled:
|
||||
latents = self._prepare_latents_free_noise(
|
||||
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
@@ -634,6 +619,11 @@ class AnimateDiffControlNetPipeline(
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
@@ -728,7 +718,7 @@ class AnimateDiffControlNetPipeline(
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
decode_chunk_size: int = 16,
|
||||
decode_batch_size: int = 16,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
@@ -1064,7 +1054,7 @@ class AnimateDiffControlNetPipeline(
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents, decode_chunk_size)
|
||||
video_tensor = self.decode_latents(latents, decode_batch_size)
|
||||
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
|
||||
|
||||
# 10. Offload all models
|
||||
|
||||
@@ -35,7 +35,6 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..free_init_utils import FreeInitMixin
|
||||
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
@@ -177,7 +176,6 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
IPAdapterMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for video-to-video generation.
|
||||
@@ -500,29 +498,15 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
def encode_video(self, video, generator, decode_chunk_size: int = 16) -> torch.Tensor:
|
||||
latents = []
|
||||
for i in range(0, len(video), decode_chunk_size):
|
||||
batch_video = video[i : i + decode_chunk_size]
|
||||
batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator)
|
||||
latents.append(batch_video)
|
||||
return torch.cat(latents)
|
||||
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
|
||||
def decode_latents(self, latents, decode_chunk_size: int = 16):
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
batch_size, channels, num_frames, height, width = latents.shape
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
video = []
|
||||
for i in range(0, latents.shape[0], decode_chunk_size):
|
||||
batch_latents = latents[i : i + decode_chunk_size]
|
||||
batch_latents = self.vae.decode(batch_latents).sample
|
||||
video.append(batch_latents)
|
||||
|
||||
video = torch.cat(video)
|
||||
video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
image = self.vae.decode(latents).sample
|
||||
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
video = video.float()
|
||||
return video
|
||||
@@ -638,7 +622,6 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
decode_chunk_size: int = 16,
|
||||
):
|
||||
if latents is None:
|
||||
num_frames = video.shape[1]
|
||||
@@ -673,11 +656,13 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
)
|
||||
|
||||
init_latents = [
|
||||
self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
|
||||
retrieve_latents(self.vae.encode(video[i]), generator=generator[i]).unsqueeze(0)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [self.encode_video(vid, generator, decode_chunk_size).unsqueeze(0) for vid in video]
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(vid), generator=generator).unsqueeze(0) for vid in video
|
||||
]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
|
||||
@@ -762,7 +747,6 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
decode_chunk_size: int = 16,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
@@ -838,8 +822,6 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
decode_chunk_size (`int`, defaults to `16`):
|
||||
The number of frames to decode at a time when calling `decode_latents` method.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -941,7 +923,6 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
decode_chunk_size=decode_chunk_size,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
@@ -1009,7 +990,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents, decode_chunk_size)
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
|
||||
|
||||
# 10. Offload all models
|
||||
|
||||
@@ -18,7 +18,6 @@ from collections import OrderedDict
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..utils import is_sentencepiece_available
|
||||
from .aura_flow import AuraFlowPipeline
|
||||
from .controlnet import (
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
@@ -48,11 +47,11 @@ from .kandinsky2_2 import (
|
||||
KandinskyV22Pipeline,
|
||||
)
|
||||
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
|
||||
from .kolors import KolorsImg2ImgPipeline, KolorsPipeline
|
||||
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
|
||||
from .pag import (
|
||||
HunyuanDiTPAGPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionPAGPipeline,
|
||||
StableDiffusionXLControlNetPAGPipeline,
|
||||
@@ -85,7 +84,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion", StableDiffusionPipeline),
|
||||
("stable-diffusion-xl", StableDiffusionXLPipeline),
|
||||
("stable-diffusion-3", StableDiffusion3Pipeline),
|
||||
("stable-diffusion-3-pag", StableDiffusion3PAGPipeline),
|
||||
("if", IFPipeline),
|
||||
("hunyuan", HunyuanDiTPipeline),
|
||||
("hunyuan-pag", HunyuanDiTPAGPipeline),
|
||||
@@ -105,6 +103,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGPipeline),
|
||||
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
|
||||
("auraflow", AuraFlowPipeline),
|
||||
("kolors", KolorsPipeline),
|
||||
("flux", FluxPipeline),
|
||||
]
|
||||
)
|
||||
@@ -122,6 +121,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
|
||||
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
|
||||
("lcm", LatentConsistencyModelImg2ImgPipeline),
|
||||
("kolors", KolorsImg2ImgPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -160,14 +160,6 @@ _AUTO_INPAINT_DECODER_PIPELINES_MAPPING = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
if is_sentencepiece_available():
|
||||
from .kolors import KolorsPipeline
|
||||
from .pag import KolorsPAGPipeline
|
||||
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors-pag"] = KolorsPAGPipeline
|
||||
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline
|
||||
|
||||
SUPPORTED_TASKS_MAPPINGS = [
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
|
||||
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
|
||||
|
||||
@@ -36,11 +36,10 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CogVideoXPipeline
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
|
||||
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda")
|
||||
>>> prompt = (
|
||||
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
|
||||
... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
|
||||
@@ -49,7 +48,9 @@ EXAMPLE_DOC_STRING = """
|
||||
... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
|
||||
... "atmosphere of this unique musical performance."
|
||||
... )
|
||||
>>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
||||
>>> video = pipe(
|
||||
... "a polar bear dancing, high quality, realistic", guidance_scale=6, num_inference_steps=20
|
||||
... ).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
"""
|
||||
@@ -153,7 +154,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
||||
"""
|
||||
|
||||
_optional_components = []
|
||||
_optional_components = ["tokenizer", "text_encoder"]
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
|
||||
_callback_tensor_inputs = [
|
||||
@@ -181,6 +182,9 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
self.vae_scale_factor_temporal = (
|
||||
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
|
||||
)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 226
|
||||
)
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
@@ -210,7 +214,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
@@ -332,11 +336,20 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
def decode_latents(self, latents: torch.Tensor, num_seconds: int):
|
||||
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
frames = self.vae.decode(latents).sample
|
||||
frames = []
|
||||
for i in range(num_seconds):
|
||||
# Whether or not to clear fake context parallel cache
|
||||
fake_cp = i + 1 < num_seconds
|
||||
start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
|
||||
|
||||
current_frames = self.vae.decode(latents[:, :, start_frame:end_frame], fake_cp=fake_cp).sample
|
||||
frames.append(current_frames)
|
||||
|
||||
frames = torch.cat(frames, dim=2)
|
||||
return frames
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
@@ -429,11 +442,11 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 480,
|
||||
width: int = 720,
|
||||
num_frames: int = 49,
|
||||
num_frames: int = 48,
|
||||
fps: int = 8,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
guidance_scale: float = 6,
|
||||
use_dynamic_cfg: bool = False,
|
||||
num_videos_per_prompt: int = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
@@ -446,7 +459,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 226,
|
||||
use_dynamic_cfg: bool = False,
|
||||
) -> Union[CogVideoXPipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -512,9 +525,6 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `226`):
|
||||
Maximum sequence length in encoded prompt. Must be consistent with
|
||||
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -524,10 +534,9 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if num_frames > 49:
|
||||
raise ValueError(
|
||||
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
|
||||
)
|
||||
assert (
|
||||
num_frames <= 48 and num_frames % fps == 0 and fps == 8
|
||||
), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX."
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
@@ -572,7 +581,6 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
@@ -584,6 +592,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
num_frames += 1
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
latent_channels,
|
||||
@@ -662,8 +671,8 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if not output_type == "latent":
|
||||
video = self.decode_latents(latents)
|
||||
if not output_type == "latents":
|
||||
video = self.decode_latents(latents, num_frames // fps)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
@@ -76,13 +76,13 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> import numpy as np
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> from transformers import DPTImageProcessor, DPTForDepthEstimation
|
||||
>>> from transformers import DPTFeatureExtractor, DPTForDepthEstimation
|
||||
>>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
|
||||
>>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
|
||||
>>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
>>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
>>> controlnet = ControlNetModel.from_pretrained(
|
||||
... "diffusers/controlnet-depth-sdxl-1.0-small",
|
||||
... variant="fp16",
|
||||
|
||||
@@ -23,7 +23,7 @@ from flax.core.frozen_dict import FrozenDict
|
||||
from flax.jax_utils import unreplicate
|
||||
from flax.training.common_utils import shard
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
|
||||
|
||||
from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel
|
||||
from ...schedulers import (
|
||||
@@ -149,7 +149,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
|
||||
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
|
||||
],
|
||||
safety_checker: FlaxStableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -23,9 +23,6 @@ except OptionalDependencyNotAvailable:
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_3_controlnet"] = ["StableDiffusion3ControlNetPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_3_controlnet_inpainting"] = [
|
||||
"StableDiffusion3ControlNetInpaintingPipeline"
|
||||
]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -36,7 +33,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_stable_diffusion_3_controlnet import StableDiffusion3ControlNetPipeline
|
||||
from .pipeline_stable_diffusion_3_controlnet_inpainting import StableDiffusion3ControlNetInpaintingPipeline
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
|
||||
-1144
File diff suppressed because it is too large
Load Diff
+4
-4
@@ -16,7 +16,7 @@ import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ....image_processor import VaeImageProcessor
|
||||
from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
@@ -66,8 +66,8 @@ class StableDiffusionModelEditingPipeline(
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
feature_extractor ([`~transformers.CLIPFeatureExtractor`]):
|
||||
A `CLIPFeatureExtractor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
with_to_k ([`bool`]):
|
||||
Whether to edit the key projection matrices along with the value projection matrices.
|
||||
with_augs ([`list`]):
|
||||
@@ -86,7 +86,7 @@ class StableDiffusionModelEditingPipeline(
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: SchedulerMixin,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
with_to_k: bool = True,
|
||||
with_augs: list = AUGS_CONST,
|
||||
|
||||
@@ -23,7 +23,6 @@ except OptionalDependencyNotAvailable:
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_flux"] = ["FluxPipeline"]
|
||||
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
@@ -32,7 +31,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_flux import FluxPipeline
|
||||
from .pipeline_flux_controlnet import FluxControlNetPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -1,861 +0,0 @@
|
||||
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import (
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnet_flux import FluxControlNetModel
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import FluxPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> from diffusers import FluxControlNetPipeline
|
||||
>>> from diffusers import FluxControlNetModel
|
||||
|
||||
>>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny-alpha"
|
||||
>>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
|
||||
>>> pipe = FluxControlNetPipeline.from_pretrained(
|
||||
... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
|
||||
>>> prompt = "A girl in city, 25 years old, cool, futuristic"
|
||||
>>> image = pipe(
|
||||
... prompt,
|
||||
... control_image=control_image,
|
||||
... controlnet_conditioning_scale=0.6,
|
||||
... num_inference_steps=28,
|
||||
... guidance_scale=3.5,
|
||||
... ).images[0]
|
||||
>>> image.save("flux.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.16,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
r"""
|
||||
The Flux pipeline for text-to-image generation.
|
||||
|
||||
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
||||
|
||||
Args:
|
||||
transformer ([`FluxTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
text_encoder_2 ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer_2 (`T5TokenizerFast`):
|
||||
Second Tokenizer of class
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder_2: T5EncoderModel,
|
||||
tokenizer_2: T5TokenizerFast,
|
||||
transformer: FluxTransformer2DModel,
|
||||
controlnet: FluxControlNetModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
self.default_sample_size = 64
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
||||
|
||||
dtype = self.text_encoder_2.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def _get_clip_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer_max_length,
|
||||
truncation=True,
|
||||
return_overflowing_tokens=False,
|
||||
return_length=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
||||
)
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
||||
|
||||
# Use pooled output of CLIPTextModel
|
||||
prompt_embeds = prompt_embeds.pooler_output
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
used in all text-encoders
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
lora_scale (`float`, *optional*):
|
||||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_2 = prompt_2 or prompt
|
||||
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
||||
|
||||
# We only use the pooled prompt output from the CLIPTextModel
|
||||
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
)
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt_2,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
||||
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_2 is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
||||
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
|
||||
def _unpack_latents(latents, height, width, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
|
||||
height = height // vae_scale_factor
|
||||
width = width // vae_scale_factor
|
||||
|
||||
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
||||
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = 2 * (int(height) // self.vae_scale_factor)
|
||||
width = 2 * (int(width) // self.vae_scale_factor)
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
||||
|
||||
return latents, latent_image_ids
|
||||
|
||||
# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
|
||||
def prepare_image(
|
||||
self,
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
do_classifier_free_guidance=False,
|
||||
guess_mode=False,
|
||||
):
|
||||
if isinstance(image, torch.Tensor):
|
||||
pass
|
||||
else:
|
||||
image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
|
||||
image_batch_size = image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
# image batch size is the same as prompt batch size
|
||||
repeat_by = num_images_per_prompt
|
||||
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if do_classifier_free_guidance and not guess_mode:
|
||||
image = torch.cat([image] * 2)
|
||||
|
||||
return image
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 28,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 7.0,
|
||||
control_image: PipelineImageInput = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
will be used instead
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
||||
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
||||
images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
dtype = self.transformer.dtype
|
||||
|
||||
lora_scale = (
|
||||
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
)
|
||||
(
|
||||
prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
text_ids,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# 3. Prepare control image
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
if isinstance(self.controlnet, FluxControlNetModel):
|
||||
control_image = self.prepare_image(
|
||||
image=control_image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
height, width = control_image.shape[-2:]
|
||||
|
||||
# vae encode
|
||||
control_image = self.vae.encode(control_image).latent_dist.sample()
|
||||
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image.shape[2:]
|
||||
control_image = self._pack_latents(
|
||||
control_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
latents, latent_image_ids = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
image_seq_len = latents.shape[1]
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.base_image_seq_len,
|
||||
self.scheduler.config.max_image_seq_len,
|
||||
self.scheduler.config.base_shift,
|
||||
self.scheduler.config.max_shift,
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
# handle guidance
|
||||
if self.transformer.config.guidance_embeds:
|
||||
guidance = torch.tensor([guidance_scale], device=device)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
# controlnet
|
||||
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
|
||||
hidden_states=latents,
|
||||
controlnet_cond=control_image,
|
||||
conditioning_scale=controlnet_conditioning_scale,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
controlnet_single_block_samples=controlnet_single_block_samples,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return FluxPipelineOutput(images=image)
|
||||
@@ -1,236 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock
|
||||
from ..models.unets.unet_motion_model import (
|
||||
CrossAttnDownBlockMotion,
|
||||
DownBlockMotion,
|
||||
UpBlockMotion,
|
||||
)
|
||||
from ..utils import logging
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class AnimateDiffFreeNoiseMixin:
|
||||
r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""
|
||||
|
||||
def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
|
||||
r"""Helper function to enable FreeNoise in transformer blocks."""
|
||||
|
||||
for motion_module in block.motion_modules:
|
||||
num_transformer_blocks = len(motion_module.transformer_blocks)
|
||||
|
||||
for i in range(num_transformer_blocks):
|
||||
if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock):
|
||||
motion_module.transformer_blocks[i].set_free_noise_properties(
|
||||
self._free_noise_context_length,
|
||||
self._free_noise_context_stride,
|
||||
self._free_noise_weighting_scheme,
|
||||
)
|
||||
else:
|
||||
assert isinstance(motion_module.transformer_blocks[i], BasicTransformerBlock)
|
||||
basic_transfomer_block = motion_module.transformer_blocks[i]
|
||||
|
||||
motion_module.transformer_blocks[i] = FreeNoiseTransformerBlock(
|
||||
dim=basic_transfomer_block.dim,
|
||||
num_attention_heads=basic_transfomer_block.num_attention_heads,
|
||||
attention_head_dim=basic_transfomer_block.attention_head_dim,
|
||||
dropout=basic_transfomer_block.dropout,
|
||||
cross_attention_dim=basic_transfomer_block.cross_attention_dim,
|
||||
activation_fn=basic_transfomer_block.activation_fn,
|
||||
attention_bias=basic_transfomer_block.attention_bias,
|
||||
only_cross_attention=basic_transfomer_block.only_cross_attention,
|
||||
double_self_attention=basic_transfomer_block.double_self_attention,
|
||||
positional_embeddings=basic_transfomer_block.positional_embeddings,
|
||||
num_positional_embeddings=basic_transfomer_block.num_positional_embeddings,
|
||||
context_length=self._free_noise_context_length,
|
||||
context_stride=self._free_noise_context_stride,
|
||||
weighting_scheme=self._free_noise_weighting_scheme,
|
||||
).to(device=self.device, dtype=self.dtype)
|
||||
|
||||
motion_module.transformer_blocks[i].load_state_dict(
|
||||
basic_transfomer_block.state_dict(), strict=True
|
||||
)
|
||||
|
||||
def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
|
||||
r"""Helper function to disable FreeNoise in transformer blocks."""
|
||||
|
||||
for motion_module in block.motion_modules:
|
||||
num_transformer_blocks = len(motion_module.transformer_blocks)
|
||||
|
||||
for i in range(num_transformer_blocks):
|
||||
if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock):
|
||||
free_noise_transfomer_block = motion_module.transformer_blocks[i]
|
||||
|
||||
motion_module.transformer_blocks[i] = BasicTransformerBlock(
|
||||
dim=free_noise_transfomer_block.dim,
|
||||
num_attention_heads=free_noise_transfomer_block.num_attention_heads,
|
||||
attention_head_dim=free_noise_transfomer_block.attention_head_dim,
|
||||
dropout=free_noise_transfomer_block.dropout,
|
||||
cross_attention_dim=free_noise_transfomer_block.cross_attention_dim,
|
||||
activation_fn=free_noise_transfomer_block.activation_fn,
|
||||
attention_bias=free_noise_transfomer_block.attention_bias,
|
||||
only_cross_attention=free_noise_transfomer_block.only_cross_attention,
|
||||
double_self_attention=free_noise_transfomer_block.double_self_attention,
|
||||
positional_embeddings=free_noise_transfomer_block.positional_embeddings,
|
||||
num_positional_embeddings=free_noise_transfomer_block.num_positional_embeddings,
|
||||
).to(device=self.device, dtype=self.dtype)
|
||||
|
||||
motion_module.transformer_blocks[i].load_state_dict(
|
||||
free_noise_transfomer_block.state_dict(), strict=True
|
||||
)
|
||||
|
||||
def _prepare_latents_free_noise(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: int,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
context_num_frames = (
|
||||
self._free_noise_context_length if self._free_noise_context_length == "repeat_context" else num_frames
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
context_num_frames,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
if self._free_noise_noise_type == "random":
|
||||
return latents
|
||||
else:
|
||||
if latents.size(2) == num_frames:
|
||||
return latents
|
||||
elif latents.size(2) != self._free_noise_context_length:
|
||||
raise ValueError(
|
||||
f"You have passed `latents` as a parameter to FreeNoise. The expected number of frames is either {num_frames} or {self._free_noise_context_length}, but found {latents.size(2)}"
|
||||
)
|
||||
latents = latents.to(device)
|
||||
|
||||
if self._free_noise_noise_type == "shuffle_context":
|
||||
for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride):
|
||||
# ensure window is within bounds
|
||||
window_start = max(0, i - self._free_noise_context_length)
|
||||
window_end = min(num_frames, window_start + self._free_noise_context_stride)
|
||||
window_length = window_end - window_start
|
||||
|
||||
if window_length == 0:
|
||||
break
|
||||
|
||||
indices = torch.LongTensor(list(range(window_start, window_end)))
|
||||
shuffled_indices = indices[torch.randperm(window_length, generator=generator)]
|
||||
|
||||
current_start = i
|
||||
current_end = min(num_frames, current_start + window_length)
|
||||
if current_end == current_start + window_length:
|
||||
# batch of frames perfectly fits the window
|
||||
latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices]
|
||||
else:
|
||||
# handle the case where the last batch of frames does not fit perfectly with the window
|
||||
prefix_length = current_end - current_start
|
||||
shuffled_indices = shuffled_indices[:prefix_length]
|
||||
latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices]
|
||||
|
||||
elif self._free_noise_noise_type == "repeat_context":
|
||||
num_repeats = (num_frames + self._free_noise_context_length - 1) // self._free_noise_context_length
|
||||
latents = torch.cat([latents] * num_repeats, dim=2)
|
||||
|
||||
latents = latents[:, :, :num_frames]
|
||||
return latents
|
||||
|
||||
def enable_free_noise(
|
||||
self,
|
||||
context_length: Optional[int] = 16,
|
||||
context_stride: int = 4,
|
||||
weighting_scheme: str = "pyramid",
|
||||
noise_type: str = "shuffle_context",
|
||||
) -> None:
|
||||
r"""
|
||||
Enable long video generation using FreeNoise.
|
||||
|
||||
Args:
|
||||
context_length (`int`, defaults to `16`, *optional*):
|
||||
The number of video frames to process at once. It's recommended to set this to the maximum frames the
|
||||
Motion Adapter was trained with (usually 16/24/32). If `None`, the default value from the motion
|
||||
adapter config is used.
|
||||
context_stride (`int`, *optional*):
|
||||
Long videos are generated by processing many frames. FreeNoise processes these frames in sliding
|
||||
windows of size `context_length`. Context stride allows you to specify how many frames to skip between
|
||||
each window. For example, a context length of 16 and context stride of 4 would process 24 frames as:
|
||||
[0, 15], [4, 19], [8, 23] (0-based indexing)
|
||||
weighting_scheme (`str`, defaults to `pyramid`):
|
||||
Weighting scheme for averaging latents after accumulation in FreeNoise blocks. The following weighting
|
||||
schemes are supported currently:
|
||||
- "pyramid"
|
||||
Peforms weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1].
|
||||
noise_type (`str`, defaults to "shuffle_context"):
|
||||
TODO
|
||||
"""
|
||||
|
||||
allowed_weighting_scheme = ["pyramid"]
|
||||
allowed_noise_type = ["shuffle_context", "repeat_context", "random"]
|
||||
|
||||
if context_length > self.motion_adapter.config.motion_max_seq_length:
|
||||
logger.warning(
|
||||
f"You have set {context_length=} which is greater than {self.motion_adapter.config.motion_max_seq_length=}. This can lead to bad generation results."
|
||||
)
|
||||
if weighting_scheme not in allowed_weighting_scheme:
|
||||
raise ValueError(
|
||||
f"The parameter `weighting_scheme` must be one of {allowed_weighting_scheme}, but got {weighting_scheme=}"
|
||||
)
|
||||
if noise_type not in allowed_noise_type:
|
||||
raise ValueError(f"The parameter `noise_type` must be one of {allowed_noise_type}, but got {noise_type=}")
|
||||
|
||||
self._free_noise_context_length = context_length or self.motion_adapter.config.motion_max_seq_length
|
||||
self._free_noise_context_stride = context_stride
|
||||
self._free_noise_weighting_scheme = weighting_scheme
|
||||
self._free_noise_noise_type = noise_type
|
||||
|
||||
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
|
||||
for block in blocks:
|
||||
self._enable_free_noise_in_block(block)
|
||||
|
||||
def disable_free_noise(self) -> None:
|
||||
self._free_noise_context_length = None
|
||||
|
||||
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
|
||||
for block in blocks:
|
||||
self._disable_free_noise_in_block(block)
|
||||
|
||||
@property
|
||||
def free_noise_enabled(self):
|
||||
return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
|
||||
@@ -5,7 +5,6 @@ from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_sentencepiece_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
@@ -15,12 +14,12 @@ _dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()) and is_sentencepiece_available():
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_sentencepiece_objects))
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_kolors"] = ["KolorsPipeline"]
|
||||
_import_structure["pipeline_kolors_img2img"] = ["KolorsImg2ImgPipeline"]
|
||||
@@ -29,10 +28,10 @@ else:
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()) and is_sentencepiece_available():
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_and_sentencepiece_objects import *
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
|
||||
else:
|
||||
from .pipeline_kolors import KolorsPipeline
|
||||
|
||||
@@ -143,18 +143,10 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
def unk_token(self) -> str:
|
||||
return "<unk>"
|
||||
|
||||
@unk_token.setter
|
||||
def unk_token(self, value: str):
|
||||
self._unk_token = value
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str:
|
||||
return "<unk>"
|
||||
|
||||
@pad_token.setter
|
||||
def pad_token(self, value: str):
|
||||
self._pad_token = value
|
||||
|
||||
@property
|
||||
def pad_token_id(self):
|
||||
return self.get_command("<pad>")
|
||||
@@ -163,10 +155,6 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
def eos_token(self) -> str:
|
||||
return "</s>"
|
||||
|
||||
@eos_token.setter
|
||||
def eos_token(self, value: str):
|
||||
self._eos_token = value
|
||||
|
||||
@property
|
||||
def eos_token_id(self):
|
||||
return self.get_command("<eos>")
|
||||
|
||||
@@ -25,10 +25,8 @@ else:
|
||||
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
|
||||
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
|
||||
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
|
||||
_import_structure["pipeline_pag_kolors"] = ["KolorsPAGPipeline"]
|
||||
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
|
||||
@@ -45,10 +43,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
|
||||
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
|
||||
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
|
||||
from .pipeline_pag_kolors import KolorsPAGPipeline
|
||||
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
|
||||
from .pipeline_pag_sd import StableDiffusionPAGPipeline
|
||||
from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
|
||||
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
|
||||
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
|
||||
from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,985 +0,0 @@
|
||||
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
|
||||
from ...models.attention_processor import PAGCFGJointAttnProcessor2_0, PAGJointAttnProcessor2_0
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import SD3Transformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
|
||||
from .pag_utils import PAGMixin
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import AutoPipelineForText2Image
|
||||
|
||||
>>> pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-3-medium-diffusers",
|
||||
... torch_dtype=torch.float16,
|
||||
... enable_pag=True,
|
||||
... pag_applied_layers=["blocks.13"],
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> image = pipe(prompt, guidance_scale=5.0, pag_scale=0.7).images[0]
|
||||
>>> image.save("sd3_pag.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, PAGMixin):
|
||||
r"""
|
||||
[PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for text-to-image generation
|
||||
using Stable Diffusion 3.
|
||||
|
||||
Args:
|
||||
transformer ([`SD3Transformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModelWithProjection`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
||||
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
|
||||
with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
|
||||
as its dimension.
|
||||
text_encoder_2 ([`CLIPTextModelWithProjection`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
||||
specifically the
|
||||
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
|
||||
variant.
|
||||
text_encoder_3 ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. Stable Diffusion 3 uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer_2 (`CLIPTokenizer`):
|
||||
Second Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer_3 (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: SD3Transformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder_2: CLIPTextModelWithProjection,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
text_encoder_3: T5EncoderModel,
|
||||
tokenizer_3: T5TokenizerFast,
|
||||
pag_applied_layers: Union[str, List[str]] = "blocks.1", # 1st transformer block
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
text_encoder_3=text_encoder_3,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
tokenizer_3=tokenizer_3,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
self.default_sample_size = (
|
||||
self.transformer.config.sample_size
|
||||
if hasattr(self, "transformer") and self.transformer is not None
|
||||
else 128
|
||||
)
|
||||
|
||||
self.set_pag_applied_layers(
|
||||
pag_applied_layers, pag_attn_processors=(PAGCFGJointAttnProcessor2_0(), PAGJointAttnProcessor2_0())
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
max_sequence_length: int = 256,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if self.text_encoder_3 is None:
|
||||
return torch.zeros(
|
||||
(
|
||||
batch_size * num_images_per_prompt,
|
||||
self.tokenizer_max_length,
|
||||
self.transformer.config.joint_attention_dim,
|
||||
),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
text_inputs = self.tokenizer_3(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
|
||||
|
||||
dtype = self.text_encoder_3.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds
|
||||
def _get_clip_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
clip_model_index: int = 0,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
|
||||
clip_tokenizers = [self.tokenizer, self.tokenizer_2]
|
||||
clip_text_encoders = [self.text_encoder, self.text_encoder_2]
|
||||
|
||||
tokenizer = clip_tokenizers[clip_model_index]
|
||||
text_encoder = clip_text_encoders[clip_model_index]
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
||||
)
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_3: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
max_sequence_length: int = 256,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
used in all text-encoders
|
||||
prompt_3 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
||||
used in all text-encoders
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
||||
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
||||
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
||||
`text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
lora_scale (`float`, *optional*):
|
||||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_2 = prompt_2 or prompt
|
||||
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
||||
|
||||
prompt_3 = prompt_3 or prompt
|
||||
prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
|
||||
|
||||
prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_model_index=0,
|
||||
)
|
||||
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
||||
prompt=prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_model_index=1,
|
||||
)
|
||||
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
|
||||
|
||||
t5_prompt_embed = self._get_t5_prompt_embeds(
|
||||
prompt=prompt_3,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
clip_prompt_embeds = torch.nn.functional.pad(
|
||||
clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
|
||||
)
|
||||
|
||||
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
|
||||
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
||||
negative_prompt_3 = negative_prompt_3 or negative_prompt
|
||||
|
||||
# normalize str to list
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
negative_prompt_2 = (
|
||||
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
||||
)
|
||||
negative_prompt_3 = (
|
||||
batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
|
||||
)
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
|
||||
negative_prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=None,
|
||||
clip_model_index=0,
|
||||
)
|
||||
negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
||||
negative_prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=None,
|
||||
clip_model_index=1,
|
||||
)
|
||||
negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
|
||||
|
||||
t5_negative_prompt_embed = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt_3,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
negative_clip_prompt_embeds = torch.nn.functional.pad(
|
||||
negative_clip_prompt_embeds,
|
||||
(0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
|
||||
)
|
||||
|
||||
negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
|
||||
negative_pooled_prompt_embeds = torch.cat(
|
||||
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
|
||||
)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
prompt_2,
|
||||
prompt_3,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
negative_prompt_2=None,
|
||||
negative_prompt_3=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
negative_pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_2 is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_3 is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
||||
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
||||
elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
|
||||
raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
|
||||
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
int(height) // self.vae_scale_factor,
|
||||
int(width) // self.vae_scale_factor,
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 28,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 7.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 256,
|
||||
pag_scale: float = 3.0,
|
||||
pag_adaptive_scale: float = 0.0,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
will be used instead
|
||||
prompt_3 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
||||
will be used instead
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
||||
`text_encoder_2`. If not defined, `negative_prompt` is used instead
|
||||
negative_prompt_3 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
||||
`text_encoder_3`. If not defined, `negative_prompt` is used instead
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
|
||||
pag_scale (`float`, *optional*, defaults to 3.0):
|
||||
The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
|
||||
guidance will not be used.
|
||||
pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
|
||||
The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
|
||||
used.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
prompt_3,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
negative_prompt_3=negative_prompt_3,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
self._pag_scale = pag_scale
|
||||
self._pag_adaptive_scale = pag_adaptive_scale #
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
lora_scale = (
|
||||
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
)
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
prompt_3=prompt_3,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
negative_prompt_3=negative_prompt_3,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
device=device,
|
||||
clip_skip=self.clip_skip,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
if self.do_perturbed_attention_guidance:
|
||||
prompt_embeds = self._prepare_perturbed_attention_guidance(
|
||||
prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
|
||||
)
|
||||
pooled_prompt_embeds = self._prepare_perturbed_attention_guidance(
|
||||
pooled_prompt_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance
|
||||
)
|
||||
elif self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
if self.do_perturbed_attention_guidance:
|
||||
original_attn_proc = self.transformer.attn_processors
|
||||
self._set_pag_attn_processor(
|
||||
pag_applied_layers=self.pag_applied_layers,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both
|
||||
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_perturbed_attention_guidance:
|
||||
noise_pred = self._apply_perturbed_attention_guidance(
|
||||
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
|
||||
)
|
||||
|
||||
elif self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
negative_pooled_prompt_embeds = callback_outputs.pop(
|
||||
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
||||
)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if self.do_perturbed_attention_guidance:
|
||||
self.transformer.set_attn_processor(original_attn_proc)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return StableDiffusion3PipelineOutput(images=image)
|
||||
@@ -35,7 +35,6 @@ from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..animatediff.pipeline_output import AnimateDiffPipelineOutput
|
||||
from ..free_init_utils import FreeInitMixin
|
||||
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from .pag_utils import PAGMixin
|
||||
|
||||
@@ -84,7 +83,6 @@ class AnimateDiffPAGPipeline(
|
||||
IPAdapterMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
PAGMixin,
|
||||
):
|
||||
r"""
|
||||
@@ -406,21 +404,15 @@ class AnimateDiffPAGPipeline(
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
|
||||
def decode_latents(self, latents, decode_chunk_size: int = 16):
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
batch_size, channels, num_frames, height, width = latents.shape
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
video = []
|
||||
for i in range(0, latents.shape[0], decode_chunk_size):
|
||||
batch_latents = latents[i : i + decode_chunk_size]
|
||||
batch_latents = self.vae.decode(batch_latents).sample
|
||||
video.append(batch_latents)
|
||||
|
||||
video = torch.cat(video)
|
||||
video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
image = self.vae.decode(latents).sample
|
||||
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
video = video.float()
|
||||
return video
|
||||
@@ -507,22 +499,10 @@ class AnimateDiffPAGPipeline(
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
# If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
|
||||
if self.free_noise_enabled:
|
||||
latents = self._prepare_latents_free_noise(
|
||||
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
@@ -530,6 +510,11 @@ class AnimateDiffPAGPipeline(
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
@@ -588,7 +573,6 @@ class AnimateDiffPAGPipeline(
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
decode_chunk_size: int = 16,
|
||||
pag_scale: float = 3.0,
|
||||
pag_adaptive_scale: float = 0.0,
|
||||
):
|
||||
@@ -847,7 +831,7 @@ class AnimateDiffPAGPipeline(
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents, decode_chunk_size)
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
|
||||
|
||||
# 10. Offload all models
|
||||
|
||||
@@ -20,7 +20,7 @@ import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from packaging import version
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, DPTForDepthEstimation, DPTImageProcessor
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
@@ -111,7 +111,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
depth_estimator: DPTForDepthEstimation,
|
||||
feature_extractor: DPTImageProcessor,
|
||||
feature_extractor: DPTFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
@@ -138,7 +138,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
+2
-2
@@ -19,7 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPFeatureExtractor,
|
||||
CLIPProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
@@ -193,7 +193,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -19,7 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
@@ -209,7 +209,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -225,7 +225,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -237,7 +237,7 @@ class StableDiffusionXLAdapterPipeline(
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -29,8 +29,8 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
|
||||
class DDIMSchedulerOutput(BaseOutput):
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->CogVideoXDDIM
|
||||
class CogVideoXDDIMSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's `step` function output.
|
||||
|
||||
@@ -312,7 +312,7 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
generator=None,
|
||||
variance_noise: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DDIMSchedulerOutput, Tuple]:
|
||||
) -> Union[CogVideoXDDIMSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
@@ -337,12 +337,13 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
||||
itself. Useful for methods such as [`CycleDiffusion`].
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
|
||||
Whether or not to return a [`~schedulers.scheduling_ddim_cogvideox.CogVideoXDDIMSchedulerOutput`] or
|
||||
`tuple`.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
[`~schedulers.scheduling_ddim_cogvideox.CogVideoXDDIMSchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`, [`~schedulers.scheduling_ddim_cogvideox.CogVideoXDDIMSchedulerOutput`] is
|
||||
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
if self.num_inference_steps is None:
|
||||
@@ -396,7 +397,7 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
return CogVideoXDDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
def add_noise(
|
||||
|
||||
@@ -30,8 +30,8 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
|
||||
class DDIMSchedulerOutput(BaseOutput):
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->CogVideoX
|
||||
class CogVideoXDPMSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's `step` function output.
|
||||
|
||||
@@ -93,6 +93,7 @@ def betas_for_alpha_bar(
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim_cogvideox.rescale_zero_terminal_snr
|
||||
def rescale_zero_terminal_snr(alphas_cumprod):
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
|
||||
@@ -231,16 +232,6 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
||||
|
||||
def _get_variance(self, timestep, prev_timestep):
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
||||
|
||||
return variance
|
||||
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
|
||||
"""
|
||||
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
||||
@@ -339,7 +330,7 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
generator=None,
|
||||
variance_noise: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = False,
|
||||
) -> Union[DDIMSchedulerOutput, Tuple]:
|
||||
) -> Union[CogVideoXDPMSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
@@ -364,12 +355,13 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
||||
itself. Useful for methods such as [`CycleDiffusion`].
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
|
||||
Whether or not to return a [`~schedulers.scheduling_dpm_cogvideox.CogVideoXDPMDPMSchedulerOutput`] or
|
||||
`tuple`.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
[`~schedulers.scheduling_dpm_cogvideox.CogVideoXDPMSchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`, [`~schedulers.scheduling_dpm_cogvideox.CogVideoXDPMSchedulerOutput`] is
|
||||
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
if self.num_inference_steps is None:
|
||||
@@ -436,7 +428,7 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
if not return_dict:
|
||||
return (prev_sample, pred_original_sample)
|
||||
|
||||
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
return CogVideoXDPMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
def add_noise(
|
||||
|
||||
@@ -78,7 +78,6 @@ from .import_utils import (
|
||||
is_peft_version,
|
||||
is_safetensors_available,
|
||||
is_scipy_available,
|
||||
is_sentencepiece_available,
|
||||
is_tensorboard_available,
|
||||
is_timm_available,
|
||||
is_torch_available,
|
||||
|
||||
@@ -182,21 +182,6 @@ class DiTTransformer2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class FluxControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class FluxTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class KolorsImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers", "sentencepiece"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers", "sentencepiece"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "sentencepiece"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "sentencepiece"])
|
||||
|
||||
|
||||
class KolorsPAGPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers", "sentencepiece"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers", "sentencepiece"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "sentencepiece"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "sentencepiece"])
|
||||
|
||||
|
||||
class KolorsPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers", "sentencepiece"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers", "sentencepiece"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "sentencepiece"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "sentencepiece"])
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user