Compare commits
33 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f63c12633f | |||
| be5995a815 | |||
| 065978474b | |||
| cc1e589537 | |||
| 8b9bfaea80 | |||
| b12c7f8390 | |||
| 06f36713ae | |||
| 19c5d7b376 | |||
| 99a64aa63c | |||
| 1bb419672d | |||
| a655574710 | |||
| 67a80dfbd5 | |||
| 1f77300d23 | |||
| 8a79d8ec39 | |||
| 2dad462d9b | |||
| e3568d14ba | |||
| f6df22447c | |||
| 9b5180cb5f | |||
| 16a93f1a25 | |||
| 2d753b6fb5 | |||
| 39e1f7eaa4 | |||
| e1b603dc2e | |||
| e4325606db | |||
| 926daa30f9 | |||
| 325a5de3a9 | |||
| 4c6152c2fb | |||
| 87e50a2f1d | |||
| a57a7af45c | |||
| 52f1378e64 | |||
| 3dc97bd148 | |||
| 6d32b29239 | |||
| bc3c73ad0b | |||
| 5934873b8f |
@@ -190,6 +190,10 @@
|
||||
- local: conceptual/evaluation
|
||||
title: Evaluating Diffusion Models
|
||||
title: Conceptual Guides
|
||||
- sections:
|
||||
- local: community_projects
|
||||
title: Projects built with Diffusers
|
||||
title: Community Projects
|
||||
- sections:
|
||||
- isExpanded: false
|
||||
sections:
|
||||
@@ -235,8 +239,12 @@
|
||||
title: VQModel
|
||||
- local: api/models/autoencoderkl
|
||||
title: AutoencoderKL
|
||||
- local: api/models/autoencoderkl_cogvideox
|
||||
title: AutoencoderKLCogVideoX
|
||||
- local: api/models/asymmetricautoencoderkl
|
||||
title: AsymmetricAutoencoderKL
|
||||
- local: api/models/stable_cascade_unet
|
||||
title: StableCascadeUNet
|
||||
- local: api/models/autoencoder_tiny
|
||||
title: Tiny AutoEncoder
|
||||
- local: api/models/autoencoder_oobleck
|
||||
@@ -257,6 +265,8 @@
|
||||
title: FluxTransformer2DModel
|
||||
- local: api/models/latte_transformer3d
|
||||
title: LatteTransformer3DModel
|
||||
- local: api/models/cogvideox_transformer3d
|
||||
title: CogVideoXTransformer3DModel
|
||||
- local: api/models/lumina_nextdit2d
|
||||
title: LuminaNextDiT2DModel
|
||||
- local: api/models/transformer_temporal
|
||||
@@ -296,6 +306,8 @@
|
||||
title: AutoPipeline
|
||||
- local: api/pipelines/blip_diffusion
|
||||
title: BLIP-Diffusion
|
||||
- local: api/pipelines/cogvideox
|
||||
title: CogVideoX
|
||||
- local: api/pipelines/consistency_models
|
||||
title: Consistency Models
|
||||
- local: api/pipelines/controlnet
|
||||
|
||||
@@ -53,6 +53,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
|
||||
- [`AutoencoderKLCogVideoX`]
|
||||
- [`ControlNetModel`]
|
||||
- [`SD3Transformer2DModel`]
|
||||
- [`FluxTransformer2DModel`]
|
||||
|
||||
## FromSingleFileMixin
|
||||
|
||||
|
||||
@@ -11,18 +11,14 @@ specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLCogVideoX
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss using CogVideoX.
|
||||
The 3D variational autoencoder (VAE) model with KL loss used in [CogVideoX](https://github.com/THUDM/CogVideo) was introduced in [CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) by Tsinghua University & ZhipuAI.
|
||||
|
||||
## Loading from the original format
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
By default, the [`AutoencoderKLCogVideoX`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded from the original format using [`FromOriginalModelMixin.from_single_file`] as follows:
|
||||
|
||||
```py
|
||||
```python
|
||||
from diffusers import AutoencoderKLCogVideoX
|
||||
|
||||
url = "THUDM/CogVideoX-2b" # can also be a local file
|
||||
model = AutoencoderKLCogVideoX.from_single_file(url)
|
||||
|
||||
vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16).to("cuda")
|
||||
```
|
||||
|
||||
## AutoencoderKLCogVideoX
|
||||
@@ -32,38 +28,10 @@ model = AutoencoderKLCogVideoX.from_single_file(url)
|
||||
- encode
|
||||
- all
|
||||
|
||||
## CogVideoXSafeConv3d
|
||||
## AutoencoderKLOutput
|
||||
|
||||
[[autodoc]] CogVideoXSafeConv3d
|
||||
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
|
||||
|
||||
## CogVideoXCausalConv3d
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] CogVideoXCausalConv3d
|
||||
|
||||
## CogVideoXSpatialNorm3D
|
||||
|
||||
[[autodoc]] CogVideoXSpatialNorm3D
|
||||
|
||||
## CogVideoXResnetBlock3D
|
||||
|
||||
[[autodoc]] CogVideoXResnetBlock3D
|
||||
|
||||
## CogVideoXDownBlock3D
|
||||
|
||||
[[autodoc]] CogVideoXDownBlock3D
|
||||
|
||||
## CogVideoXMidBlock3D
|
||||
|
||||
[[autodoc]] CogVideoXMidBlock3D
|
||||
|
||||
## CogVideoXUpBlock3D
|
||||
|
||||
[[autodoc]] CogVideoXUpBlock3D
|
||||
|
||||
## CogVideoXEncoder3D
|
||||
|
||||
[[autodoc]] CogVideoXEncoder3D
|
||||
|
||||
## CogVideoXDecoder3D
|
||||
|
||||
[[autodoc]] CogVideoXDecoder3D
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
|
||||
@@ -9,10 +9,22 @@ Unless required by applicable law or agreed to in writing, software distributed
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
## CogVideoXTransformer3DModel
|
||||
# CogVideoXTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D data from [CogVideoX](https://github.com/THUDM/CogVideoX).
|
||||
A Diffusion Transformer model for 3D data from [CogVideoX](https://github.com/THUDM/CogVideo) was introduced in [CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) by Tsinghua University & ZhipuAI.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import CogVideoXTransformer3DModel
|
||||
|
||||
vae = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-2b", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
|
||||
```
|
||||
|
||||
## CogVideoXTransformer3DModel
|
||||
|
||||
[[autodoc]] CogVideoXTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# StableCascadeUNet
|
||||
|
||||
A UNet model from the [Stable Cascade pipeline](../pipelines/stable_cascade.md).
|
||||
|
||||
## StableCascadeUNet
|
||||
|
||||
[[autodoc]] models.unets.unet_stable_cascade.StableCascadeUNet
|
||||
@@ -10,18 +10,16 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
## TODO: The paper is still being written.
|
||||
# limitations under the License.
|
||||
-->
|
||||
|
||||
# CogVideoX
|
||||
|
||||
[TODO]() from Tsinghua University & ZhipuAI.
|
||||
[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://arxiv.org/abs/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
The paper is still being written.
|
||||
*We introduce CogVideoX, a large-scale diffusion transformer model designed for generating videos based on text prompts. To efficently model video data, we propose to levearge a 3D Variational Autoencoder (VAE) to compresses videos along both spatial and temporal dimensions. To improve the text-video alignment, we propose an expert transformer with the expert adaptive LayerNorm to facilitate the deep fusion between the two modalities. By employing a progressive training technique, CogVideoX is adept at producing coherent, long-duration videos characterized by significant motion. In addition, we develop an effectively text-video data processing pipeline that includes various data preprocessing strategies and a video captioning method. It significantly helps enhance the performance of CogVideoX, improving both generation quality and semantic alignment. Results show that CogVideoX demonstrates state-of-the-art performance across both multiple machine metrics and human evaluations. The model weight of CogVideoX-2B is publicly available at https://github.com/THUDM/CogVideo.*
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -29,7 +27,13 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
|
||||
|
||||
</Tip>
|
||||
|
||||
### Inference
|
||||
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
|
||||
|
||||
There are two models available that can be used with the CogVideoX pipeline:
|
||||
- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b)
|
||||
- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b)
|
||||
|
||||
## Inference
|
||||
|
||||
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
|
||||
|
||||
@@ -37,38 +41,46 @@ First, load the pipeline:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import LattePipeline
|
||||
from diffusers import CogVideoXPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipeline = LattePipeline.from_pretrained(
|
||||
"THUDM/CogVideoX-2b", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b").to("cuda")
|
||||
```
|
||||
|
||||
Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`:
|
||||
Then change the memory layout of the pipelines `transformer` component to `torch.channels_last`:
|
||||
|
||||
```python
|
||||
pipeline.transformer.to(memory_format=torch.channels_last)
|
||||
pipeline.vae.to(memory_format=torch.channels_last)
|
||||
pipe.transformer.to(memory_format=torch.channels_last)
|
||||
```
|
||||
|
||||
Finally, compile the components and run inference:
|
||||
|
||||
```python
|
||||
pipeline.transformer = torch.compile(pipeline.transformer)
|
||||
pipeline.vae.decode = torch.compile(pipeline.vae.decode)
|
||||
pipe.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
|
||||
|
||||
# CogVideoX works very well with long and well-described prompts
|
||||
# CogVideoX works well with long and well-described prompts
|
||||
prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
|
||||
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
||||
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
||||
```
|
||||
|
||||
The [benchmark](TODO: link) results on an 80GB A100 machine are:
|
||||
The [benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are:
|
||||
|
||||
```
|
||||
Without torch.compile(): Average inference time: TODO seconds.
|
||||
With torch.compile(): Average inference time: TODO seconds.
|
||||
Without torch.compile(): Average inference time: 96.89 seconds.
|
||||
With torch.compile(): Average inference time: 76.27 seconds.
|
||||
```
|
||||
|
||||
### Memory optimization
|
||||
|
||||
CogVideoX-2b requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
|
||||
|
||||
- `pipe.enable_model_cpu_offload()`:
|
||||
- Without enabling cpu offloading, memory usage is `33 GB`
|
||||
- With enabling cpu offloading, memory usage is `19 GB`
|
||||
- `pipe.vae.enable_tiling()`:
|
||||
- With enabling cpu offloading and tiling, memory usage is `11 GB`
|
||||
- `pipe.vae.enable_slicing()`
|
||||
|
||||
## CogVideoXPipeline
|
||||
|
||||
[[autodoc]] CogVideoXPipeline
|
||||
@@ -76,4 +88,5 @@ With torch.compile(): Average inference time: TODO seconds.
|
||||
- __call__
|
||||
|
||||
## CogVideoXPipelineOutput
|
||||
[[autodoc]] pipelines.pipline_cogvideo.pipeline_output.CogVideoXPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput
|
||||
|
||||
@@ -37,7 +37,7 @@ Both checkpoints have slightly difference usage which we detail below.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
@@ -61,7 +61,7 @@ out.save("image.png")
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
@@ -77,8 +77,89 @@ out = pipe(
|
||||
out.save("image.png")
|
||||
```
|
||||
|
||||
## Running FP16 inference
|
||||
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
|
||||
|
||||
FP16 inference code:
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) # can replace schnell with dev
|
||||
# to run on low vram GPUs (i.e. between 4 and 32 GB VRAM)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
pipe.vae.enable_slicing()
|
||||
pipe.vae.enable_tiling()
|
||||
|
||||
pipe.to(torch.float16) # casting here instead of in the pipeline constructor because doing so in the constructor loads all models into CPU memory at once
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
out = pipe(
|
||||
prompt=prompt,
|
||||
guidance_scale=0.,
|
||||
height=768,
|
||||
width=1360,
|
||||
num_inference_steps=4,
|
||||
max_sequence_length=256,
|
||||
).images[0]
|
||||
out.save("image.png")
|
||||
```
|
||||
|
||||
## Single File Loading for the `FluxTransformer2DModel`
|
||||
|
||||
The `FluxTransformer2DModel` supports loading checkpoints in the original format shipped by Black Forest Labs. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
|
||||
|
||||
<Tip>
|
||||
`FP8` inference can be brittle depending on the GPU type, CUDA version, and `torch` version that you are using. It is recommended that you use the `optimum-quanto` library in order to run FP8 inference on your machine.
|
||||
</Tip>
|
||||
|
||||
The following example demonstrates how to run Flux with less than 16GB of VRAM.
|
||||
|
||||
First install `optimum-quanto`
|
||||
|
||||
```shell
|
||||
pip install optimum-quanto
|
||||
```
|
||||
|
||||
Then run the following example
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel, FluxPipeline
|
||||
from transformers import T5EncoderModel, CLIPTextModel
|
||||
from optimum.quanto import freeze, qfloat8, quantize
|
||||
|
||||
bfl_repo = "black-forest-labs/FLUX.1-dev"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype)
|
||||
quantize(transformer, weights=qfloat8)
|
||||
freeze(transformer)
|
||||
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
|
||||
quantize(text_encoder_2, weights=qfloat8)
|
||||
freeze(text_encoder_2)
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
|
||||
pipe.transformer = transformer
|
||||
pipe.text_encoder_2 = text_encoder_2
|
||||
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(
|
||||
prompt,
|
||||
guidance_scale=3.5,
|
||||
output_type="pil",
|
||||
num_inference_steps=20,
|
||||
generator=torch.Generator("cpu").manual_seed(0)
|
||||
).images[0]
|
||||
|
||||
image.save("flux-fp8-dev.png")
|
||||
```
|
||||
|
||||
## FluxPipeline
|
||||
|
||||
[[autodoc]] FluxPipeline
|
||||
- all
|
||||
- __call__
|
||||
- __call__
|
||||
|
||||
@@ -43,6 +43,11 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## KolorsPAGPipeline
|
||||
[[autodoc]] KolorsPAGPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionPAGPipeline
|
||||
[[autodoc]] StableDiffusionPAGPipeline
|
||||
- all
|
||||
@@ -74,6 +79,12 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
|
||||
- __call__
|
||||
|
||||
|
||||
## StableDiffusion3PAGPipeline
|
||||
[[autodoc]] StableDiffusion3PAGPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## PixArtSigmaPAGPipeline
|
||||
[[autodoc]] PixArtSigmaPAGPipeline
|
||||
- all
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Community Projects
|
||||
|
||||
Welcome to Community Projects. This space is dedicated to showcasing the incredible work and innovative applications created by our vibrant community using the `diffusers` library.
|
||||
|
||||
This section aims to:
|
||||
|
||||
- Highlight diverse and inspiring projects built with `diffusers`
|
||||
- Foster knowledge sharing within our community
|
||||
- Provide real-world examples of how `diffusers` can be leveraged
|
||||
|
||||
Happy exploring, and thank you for being part of the Diffusers community!
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th>Project Name</th>
|
||||
<th>Description</th>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/carson-katri/dream-textures"> dream-textures </a></td>
|
||||
<td>Stable Diffusion built-in to Blender</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/megvii-research/HiDiffusion"> HiDiffusion </a></td>
|
||||
<td>Increases the resolution and speed of your diffusion model by only adding a single line of code</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/lllyasviel/IC-Light"> IC-Light </a></td>
|
||||
<td>IC-Light is a project to manipulate the illumination of images</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/InstantID/InstantID"> InstantID </a></td>
|
||||
<td>InstantID : Zero-shot Identity-Preserving Generation in Seconds</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/Sanster/IOPaint"> IOPaint </a></td>
|
||||
<td>Image inpainting tool powered by SOTA AI Model. Remove any unwanted object, defect, people from your pictures or erase and replace(powered by stable diffusion) any thing on your pictures.</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/bmaltais/kohya_ss"> Kohya </a></td>
|
||||
<td>Gradio GUI for Kohya's Stable Diffusion trainers</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/magic-research/magic-animate"> MagicAnimate </a></td>
|
||||
<td>MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/levihsu/OOTDiffusion"> OOTDiffusion </a></td>
|
||||
<td>Outfitting Fusion based Latent Diffusion for Controllable Virtual Try-on</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/vladmandic/automatic"> SD.Next </a></td>
|
||||
<td>SD.Next: Advanced Implementation of Stable Diffusion and other Diffusion-based generative image models</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/ashawkey/stable-dreamfusion"> stable-dreamfusion </a></td>
|
||||
<td>Text-to-3D & Image-to-3D & Mesh Exportation with NeRF + Diffusion</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/HVision-NKU/StoryDiffusion"> StoryDiffusion </a></td>
|
||||
<td>StoryDiffusion can create a magic story by generating consistent images and videos.</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/cumulo-autumn/StreamDiffusion"> StreamDiffusion </a></td>
|
||||
<td>A Pipeline-Level Solution for Real-Time Interactive Generation</td>
|
||||
</tr>
|
||||
</table>
|
||||
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
[InstructPix2Pix](https://hf.co/papers/2211.09800) is a Stable Diffusion model trained to edit images from human-provided instructions. For example, your prompt can be "turn the clouds rainy" and the model will edit the input image accordingly. This model is conditioned on the text prompt (or editing instruction) and the input image.
|
||||
|
||||
This guide will explore the [train_instruct_pix2pix.py](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py) training script to help you become familiar with it, and how you can adapt it for your own use-case.
|
||||
This guide will explore the [train_instruct_pix2pix.py](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py) training script to help you become familiar with it, and how you can adapt it for your own use case.
|
||||
|
||||
Before running the script, make sure you install the library from source:
|
||||
|
||||
@@ -117,7 +117,7 @@ optimizer = optimizer_cls(
|
||||
)
|
||||
```
|
||||
|
||||
Next, the edited images and and edit instructions are [preprocessed](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L624) and [tokenized](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L610C24-L610C24). It is important the same image transformations are applied to the original and edited images.
|
||||
Next, the edited images and edit instructions are [preprocessed](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L624) and [tokenized](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L610C24-L610C24). It is important the same image transformations are applied to the original and edited images.
|
||||
|
||||
```py
|
||||
def preprocess_train(examples):
|
||||
@@ -249,4 +249,4 @@ The SDXL training script is discussed in more detail in the [SDXL training](sdxl
|
||||
|
||||
Congratulations on training your own InstructPix2Pix model! 🥳 To learn more about the model, it may be helpful to:
|
||||
|
||||
- Read the [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd) blog post to learn more about some experiments we've done with InstructPix2Pix, dataset preparation, and results for different instructions.
|
||||
- Read the [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd) blog post to learn more about some experiments we've done with InstructPix2Pix, dataset preparation, and results for different instructions.
|
||||
|
||||
@@ -34,7 +34,7 @@ pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")
|
||||
```
|
||||
|
||||
Next, load a [CiroN2022/toy-face](https://huggingface.co/CiroN2022/toy-face) adapter with the [`~diffusers.loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] method. With the 🤗 PEFT integration, you can assign a specific `adapter_name` to the checkpoint, which let's you easily switch between different LoRA checkpoints. Let's call this adapter `"toy"`.
|
||||
Next, load a [CiroN2022/toy-face](https://huggingface.co/CiroN2022/toy-face) adapter with the [`~diffusers.loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] method. With the 🤗 PEFT integration, you can assign a specific `adapter_name` to the checkpoint, which lets you easily switch between different LoRA checkpoints. Let's call this adapter `"toy"`.
|
||||
|
||||
```python
|
||||
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
|
||||
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Pipeline callbacks
|
||||
|
||||
The denoising loop of a pipeline can be modified with custom defined functions using the `callback_on_step_end` parameter. The callback function is executed at the end of each step, and modifies the pipeline attributes and variables for the next step. This is really useful for *dynamically* adjusting certain pipeline attributes or modifying tensor variables. This versatility allows for interesting use-cases such as changing the prompt embeddings at each timestep, assigning different weights to the prompt embeddings, and editing the guidance scale. With callbacks, you can implement new features without modifying the underlying code!
|
||||
The denoising loop of a pipeline can be modified with custom defined functions using the `callback_on_step_end` parameter. The callback function is executed at the end of each step, and modifies the pipeline attributes and variables for the next step. This is really useful for *dynamically* adjusting certain pipeline attributes or modifying tensor variables. This versatility allows for interesting use cases such as changing the prompt embeddings at each timestep, assigning different weights to the prompt embeddings, and editing the guidance scale. With callbacks, you can implement new features without modifying the underlying code!
|
||||
|
||||
> [!TIP]
|
||||
> 🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point!
|
||||
@@ -75,7 +75,7 @@ out.images[0].save("official_callback.png")
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">without SDXLCFGCutoffCallback</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/with_cfg_callback.png" alt="generated image of a a sports car at the road with cfg callback" />
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/with_cfg_callback.png" alt="generated image of a sports car at the road with cfg callback" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">with SDXLCFGCutoffCallback</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -289,9 +289,9 @@ scheduler = DPMSolverMultistepScheduler.from_pretrained(pipe_id, subfolder="sche
|
||||
3. Load an image processor:
|
||||
|
||||
```python
|
||||
from transformers import CLIPFeatureExtractor
|
||||
from transformers import CLIPImageProcessor
|
||||
|
||||
feature_extractor = CLIPFeatureExtractor.from_pretrained(pipe_id, subfolder="feature_extractor")
|
||||
feature_extractor = CLIPImageProcessor.from_pretrained(pipe_id, subfolder="feature_extractor")
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
@@ -212,14 +212,14 @@ TCD-LoRA is very versatile, and it can be combined with other adapter types like
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
|
||||
from transformers import DPTImageProcessor, DPTForDepthEstimation
|
||||
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
from scheduling_tcd import TCDScheduler
|
||||
|
||||
device = "cuda"
|
||||
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
|
||||
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
|
||||
def get_depth_map(image):
|
||||
image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
|
||||
|
||||
@@ -307,7 +307,7 @@ print(pipeline)
|
||||
|
||||
위의 코드 출력 결과를 확인해보면, `pipeline`은 [`StableDiffusionPipeline`]의 인스턴스이며, 다음과 같이 총 7개의 컴포넌트로 구성된다는 것을 알 수 있습니다.
|
||||
|
||||
- `"feature_extractor"`: [`~transformers.CLIPFeatureExtractor`]의 인스턴스
|
||||
- `"feature_extractor"`: [`~transformers.CLIPImageProcessor`]의 인스턴스
|
||||
- `"safety_checker"`: 유해한 컨텐츠를 스크리닝하기 위한 [컴포넌트](https://github.com/huggingface/diffusers/blob/e55687e1e15407f60f32242027b7bb8170e58266/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L32)
|
||||
- `"scheduler"`: [`PNDMScheduler`]의 인스턴스
|
||||
- `"text_encoder"`: [`~transformers.CLIPTextModel`]의 인스턴스
|
||||
|
||||
@@ -24,7 +24,7 @@ import PIL
|
||||
from PIL import Image
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
|
||||
def image_grid(imgs, rows, cols):
|
||||
|
||||
@@ -71,7 +71,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -1435,9 +1435,9 @@ import requests
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from PIL import Image
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel
|
||||
from transformers import CLIPImageProcessor, CLIPModel
|
||||
|
||||
feature_extractor = CLIPFeatureExtractor.from_pretrained(
|
||||
feature_extractor = CLIPImageProcessor.from_pretrained(
|
||||
"laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
||||
)
|
||||
clip_model = CLIPModel.from_pretrained(
|
||||
@@ -2122,7 +2122,7 @@ import torch
|
||||
import open_clip
|
||||
from open_clip import SimpleTokenizer
|
||||
from diffusers import DiffusionPipeline
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel
|
||||
from transformers import CLIPImageProcessor, CLIPModel
|
||||
|
||||
|
||||
def download_image(url):
|
||||
@@ -2130,7 +2130,7 @@ def download_image(url):
|
||||
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
# Loading additional models
|
||||
feature_extractor = CLIPFeatureExtractor.from_pretrained(
|
||||
feature_extractor = CLIPImageProcessor.from_pretrained(
|
||||
"laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
||||
)
|
||||
clip_model = CLIPModel.from_pretrained(
|
||||
|
||||
@@ -7,7 +7,7 @@ import PIL.Image
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -86,7 +86,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline, StableDiffusionMi
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
coca_model=None,
|
||||
coca_tokenizer=None,
|
||||
coca_transform=None,
|
||||
|
||||
@@ -7,7 +7,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -32,9 +32,9 @@ EXAMPLE_DOC_STRING = """
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from PIL import Image
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel
|
||||
from transformers import CLIPImageProcessor, CLIPModel
|
||||
|
||||
feature_extractor = CLIPFeatureExtractor.from_pretrained(
|
||||
feature_extractor = CLIPImageProcessor.from_pretrained(
|
||||
"laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
||||
)
|
||||
clip_model = CLIPModel.from_pretrained(
|
||||
@@ -139,7 +139,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
|
||||
@@ -43,8 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
check_min_version("0.30.0")
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
"""
|
||||
|
||||
@@ -9,7 +9,7 @@ import torch
|
||||
from numpy import exp, pi, sqrt
|
||||
from torchvision.transforms.functional import resize
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
@@ -275,7 +275,7 @@ class StableDiffusionCanvasPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
|
||||
@@ -15,7 +15,7 @@ from diffusers.utils import logging
|
||||
|
||||
try:
|
||||
from ligo.segments import segment
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
except ImportError:
|
||||
raise ImportError("Please install transformers and ligo-segments to use the mixture pipeline")
|
||||
|
||||
@@ -144,7 +144,7 @@ class StableDiffusionTilingPipeline(DiffusionPipeline, StableDiffusionExtrasMixi
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
|
||||
@@ -189,7 +189,7 @@ class StableDiffusionXLControlNetAdapterPipeline(
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -332,7 +332,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
|
||||
Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config
|
||||
|
||||
@@ -9,7 +9,7 @@ import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
# from ...configuration_utils import FrozenDict
|
||||
# from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
@@ -87,7 +87,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
cc_projection ([`CCProjection`]):
|
||||
Projection layer to project the concated CLIP features and pose embeddings to the original CLIP feature size.
|
||||
@@ -102,7 +102,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
cc_projection: CCProjection,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torchvision.transforms.functional as FF
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
@@ -69,7 +69,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
import intel_extension_for_pytorch as ipex
|
||||
import torch
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
@@ -86,7 +86,7 @@ class StableDiffusionIPEXPipeline(
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -100,7 +100,7 @@ class StableDiffusionIPEXPipeline(
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -42,7 +42,7 @@ from polygraphy.backend.trt import (
|
||||
network_from_onnx_path,
|
||||
save_engine,
|
||||
)
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.configuration_utils import FrozenDict, deprecate
|
||||
@@ -679,7 +679,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -693,7 +693,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DDIMScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
stages=["clip", "unet", "vae", "vae_encoder"],
|
||||
|
||||
@@ -42,7 +42,7 @@ from polygraphy.backend.trt import (
|
||||
network_from_onnx_path,
|
||||
save_engine,
|
||||
)
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.configuration_utils import FrozenDict, deprecate
|
||||
@@ -683,7 +683,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -697,7 +697,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DDIMScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
stages=["clip", "unet", "vae", "vae_encoder"],
|
||||
|
||||
@@ -42,7 +42,7 @@ from polygraphy.backend.trt import (
|
||||
network_from_onnx_path,
|
||||
save_engine,
|
||||
)
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.configuration_utils import FrozenDict, deprecate
|
||||
@@ -595,7 +595,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -609,7 +609,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DDIMScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
stages=["clip", "unet", "vae"],
|
||||
|
||||
@@ -73,7 +73,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -70,7 +70,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1271,7 +1271,7 @@ def main(args):
|
||||
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
|
||||
|
||||
transformer_state_dict = {
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
|
||||
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
|
||||
}
|
||||
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
|
||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ from PIL import Image
|
||||
from torch.utils.data import default_collate
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, DPTFeatureExtractor, DPTForDepthEstimation, PretrainedConfig
|
||||
from transformers import AutoTokenizer, DPTForDepthEstimation, DPTImageProcessor, PretrainedConfig
|
||||
from webdataset.tariterators import (
|
||||
base_plus_ext,
|
||||
tar_file_expander,
|
||||
@@ -205,7 +205,7 @@ class Text2ImageDataset:
|
||||
pin_memory: bool = False,
|
||||
persistent_workers: bool = False,
|
||||
control_type: str = "canny",
|
||||
feature_extractor: Optional[DPTFeatureExtractor] = None,
|
||||
feature_extractor: Optional[DPTImageProcessor] = None,
|
||||
):
|
||||
if not isinstance(train_shards_path_or_url, str):
|
||||
train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
|
||||
@@ -1011,7 +1011,7 @@ def main(args):
|
||||
controlnet = pre_controlnet
|
||||
|
||||
if args.control_type == "depth":
|
||||
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
depth_model.requires_grad_(False)
|
||||
else:
|
||||
|
||||
@@ -45,7 +45,7 @@
|
||||
" UniPCMultistepScheduler,\n",
|
||||
" EulerDiscreteScheduler,\n",
|
||||
")\n",
|
||||
"from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n",
|
||||
"from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n",
|
||||
"# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
|
||||
"\n",
|
||||
"pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n",
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Callable, List, Optional, Union
|
||||
import torch
|
||||
from PIL import Image
|
||||
from retriever import Retriever, normalize_images, preprocess_images
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer
|
||||
from transformers import CLIPImageProcessor, CLIPModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
@@ -47,7 +47,7 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@@ -65,7 +65,7 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
],
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
retriever: Optional[Retriever] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
import torch
|
||||
from datasets import Dataset, load_dataset
|
||||
from PIL import Image
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, PretrainedConfig
|
||||
from transformers import CLIPImageProcessor, CLIPModel, PretrainedConfig
|
||||
|
||||
from diffusers import logging
|
||||
|
||||
@@ -20,7 +20,7 @@ def normalize_images(images: List[Image.Image]):
|
||||
return images
|
||||
|
||||
|
||||
def preprocess_images(images: List[np.array], feature_extractor: CLIPFeatureExtractor) -> torch.Tensor:
|
||||
def preprocess_images(images: List[np.array], feature_extractor: CLIPImageProcessor) -> torch.Tensor:
|
||||
"""
|
||||
Preprocesses a list of images into a batch of tensors.
|
||||
|
||||
@@ -95,14 +95,12 @@ class Index:
|
||||
def build_index(
|
||||
self,
|
||||
model=None,
|
||||
feature_extractor: CLIPFeatureExtractor = None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
torch_dtype=torch.float32,
|
||||
):
|
||||
if not self.index_initialized:
|
||||
model = model or CLIPModel.from_pretrained(self.config.clip_name_or_path).to(dtype=torch_dtype)
|
||||
feature_extractor = feature_extractor or CLIPFeatureExtractor.from_pretrained(
|
||||
self.config.clip_name_or_path
|
||||
)
|
||||
feature_extractor = feature_extractor or CLIPImageProcessor.from_pretrained(self.config.clip_name_or_path)
|
||||
self.dataset = get_dataset_with_emb_from_clip_model(
|
||||
self.dataset,
|
||||
model,
|
||||
@@ -136,7 +134,7 @@ class Retriever:
|
||||
index: Index = None,
|
||||
dataset: Dataset = None,
|
||||
model=None,
|
||||
feature_extractor: CLIPFeatureExtractor = None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
):
|
||||
self.config = config
|
||||
self.index = index or self._build_index(config, dataset, model=model, feature_extractor=feature_extractor)
|
||||
@@ -148,7 +146,7 @@ class Retriever:
|
||||
index: Index = None,
|
||||
dataset: Dataset = None,
|
||||
model=None,
|
||||
feature_extractor: CLIPFeatureExtractor = None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
**kwargs,
|
||||
):
|
||||
config = kwargs.pop("config", None) or IndexConfig.from_pretrained(retriever_name_or_path, **kwargs)
|
||||
@@ -156,7 +154,7 @@ class Retriever:
|
||||
|
||||
@staticmethod
|
||||
def _build_index(
|
||||
config: IndexConfig, dataset: Dataset = None, model=None, feature_extractor: CLIPFeatureExtractor = None
|
||||
config: IndexConfig, dataset: Dataset = None, model=None, feature_extractor: CLIPImageProcessor = None
|
||||
):
|
||||
dataset = dataset or load_dataset(config.dataset_name)
|
||||
dataset = dataset[config.dataset_set]
|
||||
|
||||
@@ -18,7 +18,7 @@ cc.initialize_cache("/tmp/sdxl_cache")
|
||||
NUM_DEVICES = jax.device_count()
|
||||
|
||||
# 1. Let's start by downloading the model and loading it into our pipeline class
|
||||
# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and
|
||||
# Adhering to JAX's functional approach, the model's parameters are returned separately and
|
||||
# will have to be passed to the pipeline during inference
|
||||
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
|
||||
@@ -69,7 +69,7 @@ def replicate_all(prompt_ids, neg_prompt_ids, seed):
|
||||
# to the function and tell JAX which are static arguments, that is, arguments that
|
||||
# are known at compile time and won't change. In our case, it is num_inference_steps,
|
||||
# height, width and return_latents.
|
||||
# Once the function is compiled, these parameters are ommited from future calls and
|
||||
# Once the function is compiled, these parameters are omitted from future calls and
|
||||
# cannot be changed without modifying the code and recompiling.
|
||||
def aot_compile(
|
||||
prompt=default_prompt,
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -81,7 +81,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
check_min_version("0.30.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -86,6 +86,9 @@ TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"key_layernorm_list": reassign_query_key_layernorm_inplace,
|
||||
"adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
|
||||
"embed_tokens": remove_keys_inplace,
|
||||
"freqs_sin": remove_keys_inplace,
|
||||
"freqs_cos": remove_keys_inplace,
|
||||
"position_embedding": remove_keys_inplace,
|
||||
}
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {
|
||||
@@ -123,11 +126,21 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def convert_transformer(ckpt_path: str):
|
||||
def convert_transformer(
|
||||
ckpt_path: str,
|
||||
num_layers: int,
|
||||
num_attention_heads: int,
|
||||
use_rotary_positional_embeddings: bool,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
PREFIX_KEY = "model.diffusion_model."
|
||||
|
||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
|
||||
transformer = CogVideoXTransformer3DModel()
|
||||
transformer = CogVideoXTransformer3DModel(
|
||||
num_layers=num_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
|
||||
).to(dtype=dtype)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
@@ -145,9 +158,9 @@ def convert_transformer(ckpt_path: str):
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_vae(ckpt_path: str):
|
||||
def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
|
||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
|
||||
vae = AutoencoderKLCogVideoX()
|
||||
vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
@@ -172,13 +185,26 @@ def get_args():
|
||||
)
|
||||
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16")
|
||||
parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
|
||||
parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
|
||||
)
|
||||
# For CogVideoX-2B, num_layers is 30. For 5B, it is 42
|
||||
parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
|
||||
# For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
|
||||
parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
|
||||
# For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
|
||||
parser.add_argument(
|
||||
"--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
|
||||
)
|
||||
# For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
|
||||
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
|
||||
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
|
||||
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -188,18 +214,33 @@ if __name__ == "__main__":
|
||||
transformer = None
|
||||
vae = None
|
||||
|
||||
if args.fp16 and args.bf16:
|
||||
raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
|
||||
|
||||
dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
|
||||
|
||||
if args.transformer_ckpt_path is not None:
|
||||
transformer = convert_transformer(args.transformer_ckpt_path)
|
||||
transformer = convert_transformer(
|
||||
args.transformer_ckpt_path,
|
||||
args.num_layers,
|
||||
args.num_attention_heads,
|
||||
args.use_rotary_positional_embeddings,
|
||||
dtype,
|
||||
)
|
||||
if args.vae_ckpt_path is not None:
|
||||
vae = convert_vae(args.vae_ckpt_path)
|
||||
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
|
||||
|
||||
text_encoder_id = "google/t5-v1_1-xxl"
|
||||
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
|
||||
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
|
||||
|
||||
# Apparently, the conversion does not work any more without this :shrug:
|
||||
for param in text_encoder.parameters():
|
||||
param.data = param.data.contiguous()
|
||||
|
||||
scheduler = CogVideoXDDIMScheduler.from_config(
|
||||
{
|
||||
"snr_shift_scale": 3.0,
|
||||
"snr_shift_scale": args.snr_shift_scale,
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
@@ -208,7 +249,7 @@ if __name__ == "__main__":
|
||||
"prediction_type": "v_prediction",
|
||||
"rescale_betas_zero_snr": True,
|
||||
"set_alpha_to_one": True,
|
||||
"timestep_spacing": "linspace",
|
||||
"timestep_spacing": "trailing",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -218,5 +259,10 @@ if __name__ == "__main__":
|
||||
|
||||
if args.fp16:
|
||||
pipe = pipe.to(dtype=torch.float16)
|
||||
if args.bf16:
|
||||
pipe = pipe.to(dtype=torch.bfloat16)
|
||||
|
||||
# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
|
||||
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
|
||||
# is either fp16/bf16 here).
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
|
||||
|
||||
@@ -254,7 +254,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.30.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.30.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.30.0.dev0"
|
||||
__version__ = "0.30.2"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -12,6 +12,7 @@ from .utils import (
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_scipy_available,
|
||||
is_sentencepiece_available,
|
||||
is_torch_available,
|
||||
is_torchsde_available,
|
||||
is_transformers_available,
|
||||
@@ -250,8 +251,6 @@ else:
|
||||
"AuraFlowPipeline",
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
"BlipDiffusionPipeline",
|
||||
"ChatGLMModel",
|
||||
"ChatGLMTokenizer",
|
||||
"CLIPImageProjection",
|
||||
"CogVideoXPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
@@ -286,8 +285,6 @@ else:
|
||||
"KandinskyV22Pipeline",
|
||||
"KandinskyV22PriorEmb2EmbPipeline",
|
||||
"KandinskyV22PriorPipeline",
|
||||
"KolorsImg2ImgPipeline",
|
||||
"KolorsPipeline",
|
||||
"LatentConsistencyModelImg2ImgPipeline",
|
||||
"LatentConsistencyModelPipeline",
|
||||
"LattePipeline",
|
||||
@@ -314,6 +311,7 @@ else:
|
||||
"StableDiffusion3ControlNetPipeline",
|
||||
"StableDiffusion3Img2ImgPipeline",
|
||||
"StableDiffusion3InpaintPipeline",
|
||||
"StableDiffusion3PAGPipeline",
|
||||
"StableDiffusion3Pipeline",
|
||||
"StableDiffusionAdapterPipeline",
|
||||
"StableDiffusionAttendAndExcitePipeline",
|
||||
@@ -391,6 +389,19 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"])
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403
|
||||
|
||||
_import_structure["utils.dummy_torch_and_transformers_and_sentencepiece_objects"] = [
|
||||
name for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
else:
|
||||
_import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"])
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -679,8 +690,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
AudioLDMPipeline,
|
||||
AuraFlowPipeline,
|
||||
ChatGLMModel,
|
||||
ChatGLMTokenizer,
|
||||
CLIPImageProjection,
|
||||
CogVideoXPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
@@ -715,8 +724,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
KandinskyV22Pipeline,
|
||||
KandinskyV22PriorEmb2EmbPipeline,
|
||||
KandinskyV22PriorPipeline,
|
||||
KolorsImg2ImgPipeline,
|
||||
KolorsPipeline,
|
||||
LatentConsistencyModelImg2ImgPipeline,
|
||||
LatentConsistencyModelPipeline,
|
||||
LattePipeline,
|
||||
@@ -743,6 +750,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusion3ControlNetPipeline,
|
||||
StableDiffusion3Img2ImgPipeline,
|
||||
StableDiffusion3InpaintPipeline,
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusion3Pipeline,
|
||||
StableDiffusionAdapterPipeline,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
@@ -814,6 +822,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -208,6 +208,8 @@ class IPAdapterMixin:
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
subfolder=image_encoder_subfolder,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
).to(self.device, dtype=self.dtype)
|
||||
self.register_modules(image_encoder=image_encoder)
|
||||
else:
|
||||
|
||||
@@ -1489,10 +1489,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
return_alphas: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -1577,7 +1577,26 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
return state_dict
|
||||
# For state dicts like
|
||||
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
|
||||
keys = list(state_dict.keys())
|
||||
network_alphas = {}
|
||||
for k in keys:
|
||||
if "alpha" in k:
|
||||
alpha_value = state_dict.get(k)
|
||||
if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
|
||||
alpha_value, float
|
||||
):
|
||||
network_alphas[k] = state_dict.pop(k)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
|
||||
)
|
||||
|
||||
if return_alphas:
|
||||
return state_dict, network_alphas
|
||||
else:
|
||||
return state_dict
|
||||
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
@@ -1611,7 +1630,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
state_dict, network_alphas = self.lora_state_dict(
|
||||
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
|
||||
)
|
||||
|
||||
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
@@ -1619,6 +1640,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
@@ -1628,7 +1650,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict,
|
||||
network_alphas=None,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix="text_encoder",
|
||||
lora_scale=self.lora_scale,
|
||||
@@ -1637,8 +1659,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
|
||||
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
|
||||
def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
@@ -1647,6 +1668,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
transformer (`SD3Transformer2DModel`):
|
||||
The Transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
@@ -1678,7 +1703,12 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
if "lora_B" in key:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
|
||||
if network_alphas is not None and len(network_alphas) >= 1:
|
||||
prefix = cls.transformer_name
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
|
||||
@@ -23,6 +23,7 @@ from packaging import version
|
||||
from ..utils import deprecate, is_transformers_available, logging
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
_is_legacy_scheduler_kwargs,
|
||||
_is_model_weights_in_cached_folder,
|
||||
_legacy_load_clip_tokenizer,
|
||||
_legacy_load_safety_checker,
|
||||
@@ -42,7 +43,6 @@ logger = logging.get_logger(__name__)
|
||||
# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
|
||||
SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
import transformers
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
@@ -135,7 +135,7 @@ def load_single_file_sub_model(
|
||||
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
|
||||
)
|
||||
|
||||
elif is_diffusers_scheduler and is_legacy_loading:
|
||||
elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
|
||||
loaded_sub_model = _legacy_load_scheduler(
|
||||
class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
convert_animatediff_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
@@ -74,6 +75,13 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"MotionAdapter": {
|
||||
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
||||
},
|
||||
"SparseControlNetModel": {
|
||||
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
||||
},
|
||||
"FluxTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -74,9 +74,15 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
|
||||
"stable_cascade_stage_c": "clip_txt_mapper.weight",
|
||||
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
|
||||
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe",
|
||||
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
|
||||
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
|
||||
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
|
||||
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
|
||||
"animatediff_rgb": "controlnet_cond_embedding.weight",
|
||||
"flux": [
|
||||
"double_blocks.0.img_attn.norm.key_norm.scale",
|
||||
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
|
||||
],
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -85,11 +91,11 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"},
|
||||
"playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"},
|
||||
"upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"},
|
||||
"inpainting": {"pretrained_model_name_or_path": "runwayml/stable-diffusion-inpainting"},
|
||||
"inpainting": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8-inpainting"},
|
||||
"inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
|
||||
"controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
|
||||
"v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
|
||||
"v1": {"pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5"},
|
||||
"v1": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8"},
|
||||
"stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
|
||||
"stable_cascade_stage_b_lite": {
|
||||
"pretrained_model_name_or_path": "stabilityai/stable-cascade",
|
||||
@@ -110,6 +116,10 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
|
||||
"animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
|
||||
"animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
|
||||
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
|
||||
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
|
||||
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
|
||||
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -251,7 +261,7 @@ SCHEDULER_DEFAULT_CONFIG = {
|
||||
"timestep_spacing": "leading",
|
||||
}
|
||||
|
||||
LDM_VAE_KEY = "first_stage_model."
|
||||
LDM_VAE_KEYS = ["first_stage_model.", "vae."]
|
||||
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
|
||||
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
|
||||
LDM_UNET_KEY = "model.diffusion_model."
|
||||
@@ -260,8 +270,8 @@ LDM_CLIP_PREFIX_TO_REMOVE = [
|
||||
"cond_stage_model.transformer.",
|
||||
"conditioner.embedders.0.transformer.",
|
||||
]
|
||||
OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
|
||||
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
|
||||
SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]
|
||||
|
||||
VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
|
||||
|
||||
@@ -311,6 +321,10 @@ def _is_model_weights_in_cached_folder(cached_folder, name):
|
||||
return weights_exist
|
||||
|
||||
|
||||
def _is_legacy_scheduler_kwargs(kwargs):
|
||||
return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys())
|
||||
|
||||
|
||||
def load_single_file_checkpoint(
|
||||
pretrained_model_link_or_path,
|
||||
force_download=False,
|
||||
@@ -491,7 +505,13 @@ def infer_diffusers_model_type(checkpoint):
|
||||
model_type = "sd3"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
|
||||
if CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
|
||||
if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
|
||||
model_type = "animatediff_scribble"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint:
|
||||
model_type = "animatediff_rgb"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
|
||||
model_type = "animatediff_v2"
|
||||
|
||||
elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
|
||||
@@ -503,6 +523,13 @@ def infer_diffusers_model_type(checkpoint):
|
||||
else:
|
||||
model_type = "animatediff_v3"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
|
||||
if any(
|
||||
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
|
||||
):
|
||||
model_type = "flux-dev"
|
||||
else:
|
||||
model_type = "flux-schnell"
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -1158,7 +1185,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
# remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
|
||||
vae_state_dict = {}
|
||||
keys = list(checkpoint.keys())
|
||||
vae_key = LDM_VAE_KEY if any(k.startswith(LDM_VAE_KEY) for k in keys) else ""
|
||||
vae_key = ""
|
||||
for ldm_vae_key in LDM_VAE_KEYS:
|
||||
if any(k.startswith(ldm_vae_key) for k in keys):
|
||||
vae_key = ldm_vae_key
|
||||
|
||||
for key in keys:
|
||||
if key.startswith(vae_key):
|
||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||
@@ -1459,14 +1490,22 @@ def _legacy_load_scheduler(
|
||||
|
||||
if scheduler_type is not None:
|
||||
deprecation_message = (
|
||||
"Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`."
|
||||
"Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`\n\n"
|
||||
"Example:\n\n"
|
||||
"from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
|
||||
"scheduler = DDIMScheduler()\n"
|
||||
"pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
|
||||
)
|
||||
deprecate("scheduler_type", "1.0.0", deprecation_message)
|
||||
|
||||
if prediction_type is not None:
|
||||
deprecation_message = (
|
||||
"Please configure an instance of a Scheduler with the appropriate `prediction_type` "
|
||||
"and pass the object directly to the `scheduler` argument in `from_single_file`."
|
||||
"Please configure an instance of a Scheduler with the appropriate `prediction_type` and "
|
||||
"pass the object directly to the `scheduler` argument in `from_single_file`.\n\n"
|
||||
"Example:\n\n"
|
||||
"from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
|
||||
'scheduler = DDIMScheduler(prediction_type="v_prediction")\n'
|
||||
"pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
|
||||
)
|
||||
deprecate("prediction_type", "1.0.0", deprecation_message)
|
||||
|
||||
@@ -1859,3 +1898,199 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
] = v
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
keys = list(checkpoint.keys())
|
||||
for k in keys:
|
||||
if "model.diffusion_model." in k:
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
|
||||
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
|
||||
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
|
||||
mlp_ratio = 4.0
|
||||
inner_dim = 3072
|
||||
|
||||
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
|
||||
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
## time_text_embed.timestep_embedder <- time_in
|
||||
converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
|
||||
"time_in.in_layer.weight"
|
||||
)
|
||||
converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
|
||||
converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
|
||||
"time_in.out_layer.weight"
|
||||
)
|
||||
converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
|
||||
|
||||
## time_text_embed.text_embedder <- vector_in
|
||||
converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
|
||||
converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
|
||||
converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
|
||||
"vector_in.out_layer.weight"
|
||||
)
|
||||
converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
|
||||
|
||||
# guidance
|
||||
has_guidance = any("guidance" in k for k in checkpoint)
|
||||
if has_guidance:
|
||||
converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
|
||||
"guidance_in.in_layer.weight"
|
||||
)
|
||||
converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
|
||||
"guidance_in.in_layer.bias"
|
||||
)
|
||||
converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
|
||||
"guidance_in.out_layer.weight"
|
||||
)
|
||||
converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
|
||||
"guidance_in.out_layer.bias"
|
||||
)
|
||||
|
||||
# context_embedder
|
||||
converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
|
||||
converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
|
||||
|
||||
# x_embedder
|
||||
converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
|
||||
converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
|
||||
|
||||
# double transformer blocks
|
||||
for i in range(num_layers):
|
||||
block_prefix = f"transformer_blocks.{i}."
|
||||
# norms.
|
||||
## norm1
|
||||
converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_mod.lin.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_mod.lin.bias"
|
||||
)
|
||||
## norm1_context
|
||||
converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mod.lin.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mod.lin.bias"
|
||||
)
|
||||
# Q, K, V
|
||||
sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
|
||||
context_q, context_k, context_v = torch.chunk(
|
||||
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
|
||||
)
|
||||
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
|
||||
checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
|
||||
)
|
||||
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
|
||||
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
|
||||
# qk_norm
|
||||
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_attn.norm.query_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_attn.norm.key_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
|
||||
)
|
||||
# ff img_mlp
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_mlp.0.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mlp.0.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mlp.0.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mlp.2.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_mlp.2.bias"
|
||||
)
|
||||
# output projections.
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.img_attn.proj.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
|
||||
f"double_blocks.{i}.txt_attn.proj.bias"
|
||||
)
|
||||
|
||||
# single transfomer blocks
|
||||
for i in range(num_single_layers):
|
||||
block_prefix = f"single_transformer_blocks.{i}."
|
||||
# norm.linear <- single_blocks.0.modulation.lin
|
||||
converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
|
||||
f"single_blocks.{i}.modulation.lin.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
|
||||
f"single_blocks.{i}.modulation.lin.bias"
|
||||
)
|
||||
# Q, K, V, mlp
|
||||
mlp_hidden_dim = int(inner_dim * mlp_ratio)
|
||||
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
|
||||
q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
|
||||
q_bias, k_bias, v_bias, mlp_bias = torch.split(
|
||||
checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
|
||||
converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
|
||||
converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
|
||||
# qk norm
|
||||
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
|
||||
f"single_blocks.{i}.norm.query_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
|
||||
f"single_blocks.{i}.norm.key_norm.scale"
|
||||
)
|
||||
# output projections.
|
||||
converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
|
||||
converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
|
||||
|
||||
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
|
||||
checkpoint.pop("final_layer.adaLN_modulation.1.weight")
|
||||
)
|
||||
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
|
||||
checkpoint.pop("final_layer.adaLN_modulation.1.bias")
|
||||
)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -272,6 +272,17 @@ class BasicTransformerBlock(nn.Module):
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.dropout = dropout
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.double_self_attention = double_self_attention
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.positional_embeddings = positional_embeddings
|
||||
self.num_positional_embeddings = num_positional_embeddings
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
# We keep these boolean flags for backward-compatibility.
|
||||
@@ -376,7 +387,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
"layer_norm",
|
||||
)
|
||||
|
||||
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm", "ada_norm_continuous"]:
|
||||
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
|
||||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
elif norm_type == "layer_norm_i2vgen":
|
||||
self.norm3 = None
|
||||
@@ -782,6 +793,319 @@ class SkipFFTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class FreeNoiseTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A FreeNoise Transformer block.
|
||||
|
||||
Parameters:
|
||||
dim (`int`):
|
||||
The number of channels in the input and output.
|
||||
num_attention_heads (`int`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`):
|
||||
The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*):
|
||||
The size of the encoder_hidden_states vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
||||
Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (`int`, *optional*):
|
||||
The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (`bool`, defaults to `False`):
|
||||
Configure if the attentions should contain a bias parameter.
|
||||
only_cross_attention (`bool`, defaults to `False`):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, defaults to `False`):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_type (`str`, defaults to `"layer_norm"`):
|
||||
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
||||
final_dropout (`bool` defaults to `False`):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
attention_type (`str`, defaults to `"default"`):
|
||||
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
||||
positional_embeddings (`str`, *optional*):
|
||||
The type of positional embeddings to apply to.
|
||||
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
||||
The maximum number of positional embeddings to apply.
|
||||
ff_inner_dim (`int`, *optional*):
|
||||
Hidden dimension of feed-forward MLP.
|
||||
ff_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in feed-forward MLP.
|
||||
attention_out_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in attention output project layer.
|
||||
context_length (`int`, defaults to `16`):
|
||||
The maximum number of frames that the FreeNoise block processes at once.
|
||||
context_stride (`int`, defaults to `4`):
|
||||
The number of frames to be skipped before starting to process a new batch of `context_length` frames.
|
||||
weighting_scheme (`str`, defaults to `"pyramid"`):
|
||||
The weighting scheme to use for weighting averaging of processed latent frames. As described in the
|
||||
Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
|
||||
used.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout: float = 0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm",
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = False,
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
context_length: int = 16,
|
||||
context_stride: int = 4,
|
||||
weighting_scheme: str = "pyramid",
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.dropout = dropout
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.double_self_attention = double_self_attention
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.positional_embeddings = positional_embeddings
|
||||
self.num_positional_embeddings = num_positional_embeddings
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
|
||||
|
||||
# We keep these boolean flags for backward-compatibility.
|
||||
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
||||
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
||||
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
||||
self.use_layer_norm = norm_type == "layer_norm"
|
||||
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
||||
|
||||
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
||||
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
||||
)
|
||||
|
||||
self.norm_type = norm_type
|
||||
self.num_embeds_ada_norm = num_embeds_ada_norm
|
||||
|
||||
if positional_embeddings and (num_positional_embeddings is None):
|
||||
raise ValueError(
|
||||
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
||||
)
|
||||
|
||||
if positional_embeddings == "sinusoidal":
|
||||
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
# 2. Cross-Attn
|
||||
if cross_attention_dim is not None or double_self_attention:
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
|
||||
# 3. Feed-forward
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
|
||||
frame_indices = []
|
||||
for i in range(0, num_frames - self.context_length + 1, self.context_stride):
|
||||
window_start = i
|
||||
window_end = min(num_frames, i + self.context_length)
|
||||
frame_indices.append((window_start, window_end))
|
||||
return frame_indices
|
||||
|
||||
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
|
||||
if weighting_scheme == "pyramid":
|
||||
if num_frames % 2 == 0:
|
||||
# num_frames = 4 => [1, 2, 2, 1]
|
||||
weights = list(range(1, num_frames // 2 + 1))
|
||||
weights = weights + weights[::-1]
|
||||
else:
|
||||
# num_frames = 5 => [1, 2, 3, 2, 1]
|
||||
weights = list(range(1, num_frames // 2 + 1))
|
||||
weights = weights + [num_frames // 2 + 1] + weights[::-1]
|
||||
else:
|
||||
raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
|
||||
|
||||
return weights
|
||||
|
||||
def set_free_noise_properties(
|
||||
self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
|
||||
) -> None:
|
||||
self.context_length = context_length
|
||||
self.context_stride = context_stride
|
||||
self.weighting_scheme = weighting_scheme
|
||||
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if cross_attention_kwargs is not None:
|
||||
if cross_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
||||
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
||||
|
||||
# hidden_states: [B x H x W, F, C]
|
||||
device = hidden_states.device
|
||||
dtype = hidden_states.dtype
|
||||
|
||||
num_frames = hidden_states.size(1)
|
||||
frame_indices = self._get_frame_indices(num_frames)
|
||||
frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
|
||||
frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
|
||||
is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
|
||||
|
||||
# Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
|
||||
# For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
|
||||
# [(0, 16), (4, 20), (8, 24), (10, 26)]
|
||||
if not is_last_frame_batch_complete:
|
||||
if num_frames < self.context_length:
|
||||
raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
|
||||
last_frame_batch_length = num_frames - frame_indices[-1][1]
|
||||
frame_indices.append((num_frames - self.context_length, num_frames))
|
||||
|
||||
num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
|
||||
accumulated_values = torch.zeros_like(hidden_states)
|
||||
|
||||
for i, (frame_start, frame_end) in enumerate(frame_indices):
|
||||
# The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle
|
||||
# cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or
|
||||
# essentially a non-multiple of `context_length`.
|
||||
weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
|
||||
weights *= frame_weights
|
||||
|
||||
hidden_states_chunk = hidden_states[:, frame_start:frame_end]
|
||||
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 1. Self-Attention
|
||||
norm_hidden_states = self.norm1(hidden_states_chunk)
|
||||
|
||||
if self.pos_embed is not None:
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states_chunk = attn_output + hidden_states_chunk
|
||||
if hidden_states_chunk.ndim == 4:
|
||||
hidden_states_chunk = hidden_states_chunk.squeeze(1)
|
||||
|
||||
# 2. Cross-Attention
|
||||
if self.attn2 is not None:
|
||||
norm_hidden_states = self.norm2(hidden_states_chunk)
|
||||
|
||||
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states_chunk = attn_output + hidden_states_chunk
|
||||
|
||||
if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
|
||||
accumulated_values[:, -last_frame_batch_length:] += (
|
||||
hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
|
||||
)
|
||||
num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length]
|
||||
else:
|
||||
accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
|
||||
num_times_accumulated[:, frame_start:frame_end] += weights
|
||||
|
||||
hidden_states = torch.where(
|
||||
num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
|
||||
).to(dtype)
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self._chunk_size is not None:
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
r"""
|
||||
A feed-forward layer.
|
||||
|
||||
@@ -227,6 +227,7 @@ class Attention(nn.Module):
|
||||
self.to_k = None
|
||||
self.to_v = None
|
||||
|
||||
self.added_proj_bias = added_proj_bias
|
||||
if self.added_kv_proj_dim is not None:
|
||||
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
|
||||
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
|
||||
@@ -698,12 +699,15 @@ class Attention(nn.Module):
|
||||
in_features = concatenated_weights.shape[1]
|
||||
out_features = concatenated_weights.shape[0]
|
||||
|
||||
self.to_added_qkv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype)
|
||||
self.to_added_qkv.weight.copy_(concatenated_weights)
|
||||
concatenated_bias = torch.cat(
|
||||
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
|
||||
self.to_added_qkv = nn.Linear(
|
||||
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
|
||||
)
|
||||
self.to_added_qkv.bias.copy_(concatenated_bias)
|
||||
self.to_added_qkv.weight.copy_(concatenated_weights)
|
||||
if self.added_proj_bias:
|
||||
concatenated_bias = torch.cat(
|
||||
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
|
||||
)
|
||||
self.to_added_qkv.bias.copy_(concatenated_bias)
|
||||
|
||||
self.fused_projections = fuse
|
||||
|
||||
@@ -1102,6 +1106,326 @@ class JointAttnProcessor2_0:
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class PAGJointAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
context_input_ndim = encoder_hidden_states.ndim
|
||||
if context_input_ndim == 4:
|
||||
batch_size, channel, height, width = encoder_hidden_states.shape
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
# store the length of image patch sequences to create a mask that prevents interaction between patches
|
||||
# similar to making the self-attention map an identity matrix
|
||||
identity_block_size = hidden_states.shape[1]
|
||||
|
||||
# chunk
|
||||
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
|
||||
encoder_hidden_states_org, encoder_hidden_states_ptb = encoder_hidden_states.chunk(2)
|
||||
|
||||
################## original path ##################
|
||||
batch_size = encoder_hidden_states_org.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query_org = attn.to_q(hidden_states_org)
|
||||
key_org = attn.to_k(hidden_states_org)
|
||||
value_org = attn.to_v(hidden_states_org)
|
||||
|
||||
# `context` projections.
|
||||
encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
|
||||
encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
|
||||
encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
|
||||
|
||||
# attention
|
||||
query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
|
||||
key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
|
||||
value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
|
||||
|
||||
inner_dim = key_org.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states_org = F.scaled_dot_product_attention(
|
||||
query_org, key_org, value_org, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states_org = hidden_states_org.to(query_org.dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
hidden_states_org, encoder_hidden_states_org = (
|
||||
hidden_states_org[:, : residual.shape[1]],
|
||||
hidden_states_org[:, residual.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states_org = attn.to_out[0](hidden_states_org)
|
||||
# dropout
|
||||
hidden_states_org = attn.to_out[1](hidden_states_org)
|
||||
if not attn.context_pre_only:
|
||||
encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if context_input_ndim == 4:
|
||||
encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
|
||||
batch_size, channel, height, width
|
||||
)
|
||||
|
||||
################## perturbed path ##################
|
||||
|
||||
batch_size = encoder_hidden_states_ptb.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query_ptb = attn.to_q(hidden_states_ptb)
|
||||
key_ptb = attn.to_k(hidden_states_ptb)
|
||||
value_ptb = attn.to_v(hidden_states_ptb)
|
||||
|
||||
# `context` projections.
|
||||
encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
|
||||
encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
|
||||
encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
|
||||
|
||||
# attention
|
||||
query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
|
||||
key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
|
||||
value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
|
||||
|
||||
inner_dim = key_ptb.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# create a full mask with all entries set to 0
|
||||
seq_len = query_ptb.size(2)
|
||||
full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
|
||||
|
||||
# set the attention value between image patches to -inf
|
||||
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
|
||||
|
||||
# set the diagonal of the attention value between image patches to 0
|
||||
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
|
||||
|
||||
# expand the mask to match the attention weights shape
|
||||
full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
|
||||
|
||||
hidden_states_ptb = F.scaled_dot_product_attention(
|
||||
query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
|
||||
|
||||
# split the attention outputs.
|
||||
hidden_states_ptb, encoder_hidden_states_ptb = (
|
||||
hidden_states_ptb[:, : residual.shape[1]],
|
||||
hidden_states_ptb[:, residual.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
|
||||
# dropout
|
||||
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
|
||||
if not attn.context_pre_only:
|
||||
encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if context_input_ndim == 4:
|
||||
encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
|
||||
batch_size, channel, height, width
|
||||
)
|
||||
|
||||
################ concat ###############
|
||||
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class PAGCFGJointAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
context_input_ndim = encoder_hidden_states.ndim
|
||||
if context_input_ndim == 4:
|
||||
batch_size, channel, height, width = encoder_hidden_states.shape
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
identity_block_size = hidden_states.shape[
|
||||
1
|
||||
] # patch embeddings width * height (correspond to self-attention map width or height)
|
||||
|
||||
# chunk
|
||||
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
|
||||
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
|
||||
|
||||
(
|
||||
encoder_hidden_states_uncond,
|
||||
encoder_hidden_states_org,
|
||||
encoder_hidden_states_ptb,
|
||||
) = encoder_hidden_states.chunk(3)
|
||||
encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org])
|
||||
|
||||
################## original path ##################
|
||||
batch_size = encoder_hidden_states_org.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query_org = attn.to_q(hidden_states_org)
|
||||
key_org = attn.to_k(hidden_states_org)
|
||||
value_org = attn.to_v(hidden_states_org)
|
||||
|
||||
# `context` projections.
|
||||
encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
|
||||
encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
|
||||
encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
|
||||
|
||||
# attention
|
||||
query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
|
||||
key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
|
||||
value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
|
||||
|
||||
inner_dim = key_org.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states_org = F.scaled_dot_product_attention(
|
||||
query_org, key_org, value_org, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states_org = hidden_states_org.to(query_org.dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
hidden_states_org, encoder_hidden_states_org = (
|
||||
hidden_states_org[:, : residual.shape[1]],
|
||||
hidden_states_org[:, residual.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states_org = attn.to_out[0](hidden_states_org)
|
||||
# dropout
|
||||
hidden_states_org = attn.to_out[1](hidden_states_org)
|
||||
if not attn.context_pre_only:
|
||||
encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if context_input_ndim == 4:
|
||||
encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
|
||||
batch_size, channel, height, width
|
||||
)
|
||||
|
||||
################## perturbed path ##################
|
||||
|
||||
batch_size = encoder_hidden_states_ptb.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query_ptb = attn.to_q(hidden_states_ptb)
|
||||
key_ptb = attn.to_k(hidden_states_ptb)
|
||||
value_ptb = attn.to_v(hidden_states_ptb)
|
||||
|
||||
# `context` projections.
|
||||
encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
|
||||
encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
|
||||
encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
|
||||
|
||||
# attention
|
||||
query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
|
||||
key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
|
||||
value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
|
||||
|
||||
inner_dim = key_ptb.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# create a full mask with all entries set to 0
|
||||
seq_len = query_ptb.size(2)
|
||||
full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
|
||||
|
||||
# set the attention value between image patches to -inf
|
||||
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
|
||||
|
||||
# set the diagonal of the attention value between image patches to 0
|
||||
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
|
||||
|
||||
# expand the mask to match the attention weights shape
|
||||
full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
|
||||
|
||||
hidden_states_ptb = F.scaled_dot_product_attention(
|
||||
query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
|
||||
|
||||
# split the attention outputs.
|
||||
hidden_states_ptb, encoder_hidden_states_ptb = (
|
||||
hidden_states_ptb[:, : residual.shape[1]],
|
||||
hidden_states_ptb[:, residual.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
|
||||
# dropout
|
||||
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
|
||||
if not attn.context_pre_only:
|
||||
encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if context_input_ndim == 4:
|
||||
encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
|
||||
batch_size, channel, height, width
|
||||
)
|
||||
|
||||
################ concat ###############
|
||||
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class FusedJointAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
@@ -1274,6 +1598,103 @@ class AuraFlowAttnProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedAuraFlowAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing Aura Flow with fused projections."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
|
||||
raise ImportError(
|
||||
"FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
qkv = attn.to_qkv(hidden_states)
|
||||
split_size = qkv.shape[-1] // 3
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
# `context` projections.
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
|
||||
split_size = encoder_qkv.shape[-1] // 3
|
||||
(
|
||||
encoder_hidden_states_query_proj,
|
||||
encoder_hidden_states_key_proj,
|
||||
encoder_hidden_states_value_proj,
|
||||
) = torch.split(encoder_qkv, split_size, dim=-1)
|
||||
|
||||
# Reshape.
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim)
|
||||
|
||||
# Apply QK norm.
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Concatenate the projections.
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
)
|
||||
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
|
||||
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
|
||||
|
||||
query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
|
||||
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
||||
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
||||
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
# Attention.
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
if encoder_hidden_states is not None:
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
# YiYi to-do: refactor rope related functions/classes
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
@@ -1447,6 +1868,148 @@ class FluxAttnProcessor2_0:
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CogVideoXAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if image_rotary_emb is not None:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class FusedCogVideoXAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
qkv = attn.to_qkv(hidden_states)
|
||||
split_size = qkv.shape[-1] // 3
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if image_rotary_emb is not None:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class XFormersAttnAddedKVProcessor:
|
||||
r"""
|
||||
Processor for implementing memory efficient attention using xFormers.
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -7,6 +22,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..downsampling import CogVideoXDownsample3D
|
||||
@@ -16,8 +32,11 @@ from ..upsampling import CogVideoXUpsample3D
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class CogVideoXSafeConv3d(nn.Conv3d):
|
||||
"""
|
||||
r"""
|
||||
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
|
||||
"""
|
||||
|
||||
@@ -49,12 +68,12 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channels in the input tensor.
|
||||
out_channels (int): Number of output channels.
|
||||
kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolutional kernel.
|
||||
stride (int, optional): Stride of the convolution. Default is 1.
|
||||
dilation (int, optional): Dilation rate of the convolution. Default is 1.
|
||||
pad_mode (str, optional): Padding mode. Default is "constant".
|
||||
in_channels (`int`): Number of channels in the input tensor.
|
||||
out_channels (`int`): Number of output channels produced by the convolution.
|
||||
kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
|
||||
stride (`int`, defaults to `1`): Stride of the convolution.
|
||||
dilation (`int`, defaults to `1`): Dilation rate of the convolution.
|
||||
pad_mode (`str`, defaults to `"constant"`): Padding mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -98,35 +117,31 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
|
||||
self.conv_cache = None
|
||||
|
||||
def fake_cp_pass_from_previous_rank(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
dim = self.temporal_dim
|
||||
def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
kernel_size = self.time_kernel_size
|
||||
if kernel_size == 1:
|
||||
return inputs
|
||||
|
||||
inputs = inputs.transpose(0, dim)
|
||||
|
||||
if self.conv_cache is not None:
|
||||
inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0)
|
||||
else:
|
||||
inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0)
|
||||
|
||||
inputs = inputs.transpose(0, dim).contiguous()
|
||||
if kernel_size > 1:
|
||||
cached_inputs = (
|
||||
[self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
|
||||
)
|
||||
inputs = torch.cat(cached_inputs + [inputs], dim=2)
|
||||
return inputs
|
||||
|
||||
def forward(self, inputs: torch.Tensor, clear_fake_cp_cache: bool = True):
|
||||
input_parallel = self.fake_cp_pass_from_previous_rank(inputs)
|
||||
|
||||
def _clear_fake_context_parallel_cache(self):
|
||||
del self.conv_cache
|
||||
self.conv_cache = None
|
||||
if not clear_fake_cp_cache:
|
||||
self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = self.fake_context_parallel_forward(inputs)
|
||||
|
||||
self._clear_fake_context_parallel_cache()
|
||||
# Note: we could move these to the cpu for a lower maximum memory usage but its only a few
|
||||
# hundred megabytes and so let's not do it for now
|
||||
self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
|
||||
|
||||
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||
input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
|
||||
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
|
||||
|
||||
output_parallel = self.conv(input_parallel)
|
||||
output = output_parallel
|
||||
output = self.conv(inputs)
|
||||
return output
|
||||
|
||||
|
||||
@@ -142,15 +157,18 @@ class CogVideoXSpatialNorm3D(nn.Module):
|
||||
The number of channels for input to group normalization layer, and output of the spatial norm layer.
|
||||
zq_channels (`int`):
|
||||
The number of channels for the quantized vector as described in the paper.
|
||||
groups (`int`):
|
||||
Number of groups to separate the channels into for group normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
f_channels: int,
|
||||
zq_channels: int,
|
||||
groups: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
|
||||
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
|
||||
self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
|
||||
@@ -175,17 +193,26 @@ class CogVideoXResnetBlock3D(nn.Module):
|
||||
A 3D ResNet block used in the CogVideoX model.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (Optional[int], optional):
|
||||
Number of output channels. If None, defaults to `in_channels`. Default is None.
|
||||
dropout (float, optional): Dropout rate. Default is 0.0.
|
||||
temb_channels (int, optional): Number of time embedding channels. Default is 512.
|
||||
groups (int, optional): Number of groups for group normalization. Default is 32.
|
||||
eps (float, optional): Epsilon value for normalization layers. Default is 1e-6.
|
||||
non_linearity (str, optional): Activation function to use. Default is "swish".
|
||||
conv_shortcut (bool, optional): If True, use a convolutional shortcut. Default is False.
|
||||
spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
|
||||
pad_mode (str, optional): Padding mode. Default is "first".
|
||||
in_channels (`int`):
|
||||
Number of input channels.
|
||||
out_channels (`int`, *optional*):
|
||||
Number of output channels. If None, defaults to `in_channels`.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
Dropout rate.
|
||||
temb_channels (`int`, defaults to `512`):
|
||||
Number of time embedding channels.
|
||||
groups (`int`, defaults to `32`):
|
||||
Number of groups to separate the channels into for group normalization.
|
||||
eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
non_linearity (`str`, defaults to `"swish"`):
|
||||
Activation function to use.
|
||||
conv_shortcut (bool, defaults to `False`):
|
||||
Whether or not to use a convolution shortcut.
|
||||
spatial_norm_dim (`int`, *optional*):
|
||||
The dimension to use for spatial norm if it is to be used instead of group norm.
|
||||
pad_mode (str, defaults to `"first"`):
|
||||
Padding mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -217,10 +244,12 @@ class CogVideoXResnetBlock3D(nn.Module):
|
||||
self.norm1 = CogVideoXSpatialNorm3D(
|
||||
f_channels=in_channels,
|
||||
zq_channels=spatial_norm_dim,
|
||||
groups=groups,
|
||||
)
|
||||
self.norm2 = CogVideoXSpatialNorm3D(
|
||||
f_channels=out_channels,
|
||||
zq_channels=spatial_norm_dim,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
self.conv1 = CogVideoXCausalConv3d(
|
||||
@@ -250,15 +279,16 @@ class CogVideoXResnetBlock3D(nn.Module):
|
||||
inputs: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
clear_fake_cp_cache: bool = True,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = inputs
|
||||
|
||||
if zq is not None:
|
||||
hidden_states = self.norm1(hidden_states, zq)
|
||||
else:
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
if temb is not None:
|
||||
hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
|
||||
@@ -270,16 +300,13 @@ class CogVideoXResnetBlock3D(nn.Module):
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
inputs = self.conv_shortcut(inputs, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
else:
|
||||
inputs = self.conv_shortcut(inputs)
|
||||
inputs = self.conv_shortcut(inputs)
|
||||
|
||||
output_tensor = inputs + hidden_states
|
||||
return output_tensor
|
||||
hidden_states = hidden_states + inputs
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CogVideoXDownBlock3D(nn.Module):
|
||||
@@ -287,18 +314,28 @@ class CogVideoXDownBlock3D(nn.Module):
|
||||
A downsampling block used in the CogVideoX model.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
temb_channels (int): Number of time embedding channels.
|
||||
dropout (float, optional): Dropout rate. Default is 0.0.
|
||||
num_layers (int, optional): Number of layers in the block. Default is 1.
|
||||
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
|
||||
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
|
||||
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
|
||||
add_downsample (bool, optional): If True, add a downsampling layer at the end of the block. Default is True.
|
||||
downsample_padding (int, optional): Padding for the downsampling layer. Default is 0.
|
||||
compress_time (bool, optional): If True, apply temporal compression. Default is False.
|
||||
pad_mode (str, optional): Padding mode. Default is "first".
|
||||
in_channels (`int`):
|
||||
Number of input channels.
|
||||
out_channels (`int`, *optional*):
|
||||
Number of output channels. If None, defaults to `in_channels`.
|
||||
temb_channels (`int`, defaults to `512`):
|
||||
Number of time embedding channels.
|
||||
num_layers (`int`, defaults to `1`):
|
||||
Number of resnet layers.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
Dropout rate.
|
||||
resnet_eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
resnet_act_fn (`str`, defaults to `"swish"`):
|
||||
Activation function to use.
|
||||
resnet_groups (`int`, defaults to `32`):
|
||||
Number of groups to separate the channels into for group normalization.
|
||||
add_downsample (`bool`, defaults to `True`):
|
||||
Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
|
||||
compress_time (`bool`, defaults to `False`):
|
||||
Whether or not to downsample across temporal dimension.
|
||||
pad_mode (str, defaults to `"first"`):
|
||||
Padding mode.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -355,7 +392,6 @@ class CogVideoXDownBlock3D(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
clear_fake_cp_cache: bool = False,
|
||||
) -> torch.Tensor:
|
||||
for resnet in self.resnets:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
@@ -367,10 +403,10 @@ class CogVideoXDownBlock3D(nn.Module):
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache
|
||||
create_custom_forward(resnet), hidden_states, temb, zq
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache)
|
||||
hidden_states = resnet(hidden_states, temb, zq)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
@@ -384,15 +420,24 @@ class CogVideoXMidBlock3D(nn.Module):
|
||||
A middle block used in the CogVideoX model.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
temb_channels (int): Number of time embedding channels.
|
||||
dropout (float, optional): Dropout rate. Default is 0.0.
|
||||
num_layers (int, optional): Number of layers in the block. Default is 1.
|
||||
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
|
||||
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
|
||||
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
|
||||
spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
|
||||
pad_mode (str, optional): Padding mode. Default is "first".
|
||||
in_channels (`int`):
|
||||
Number of input channels.
|
||||
temb_channels (`int`, defaults to `512`):
|
||||
Number of time embedding channels.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
Dropout rate.
|
||||
num_layers (`int`, defaults to `1`):
|
||||
Number of resnet layers.
|
||||
resnet_eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
resnet_act_fn (`str`, defaults to `"swish"`):
|
||||
Activation function to use.
|
||||
resnet_groups (`int`, defaults to `32`):
|
||||
Number of groups to separate the channels into for group normalization.
|
||||
spatial_norm_dim (`int`, *optional*):
|
||||
The dimension to use for spatial norm if it is to be used instead of group norm.
|
||||
pad_mode (str, defaults to `"first"`):
|
||||
Padding mode.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -435,7 +480,6 @@ class CogVideoXMidBlock3D(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
clear_fake_cp_cache: bool = False,
|
||||
) -> torch.Tensor:
|
||||
for resnet in self.resnets:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
@@ -447,10 +491,10 @@ class CogVideoXMidBlock3D(nn.Module):
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache
|
||||
create_custom_forward(resnet), hidden_states, temb, zq
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache)
|
||||
hidden_states = resnet(hidden_states, temb, zq)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -460,19 +504,30 @@ class CogVideoXUpBlock3D(nn.Module):
|
||||
An upsampling block used in the CogVideoX model.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
temb_channels (int): Number of time embedding channels.
|
||||
dropout (float, optional): Dropout rate. Default is 0.0.
|
||||
num_layers (int, optional): Number of layers in the block. Default is 1.
|
||||
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
|
||||
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
|
||||
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
|
||||
spatial_norm_dim (int, optional): Dimension of the spatial normalization. Default is 16.
|
||||
add_upsample (bool, optional): If True, add an upsampling layer at the end of the block. Default is True.
|
||||
upsample_padding (int, optional): Padding for the upsampling layer. Default is 1.
|
||||
compress_time (bool, optional): If True, apply temporal compression. Default is False.
|
||||
pad_mode (str, optional): Padding mode. Default is "first".
|
||||
in_channels (`int`):
|
||||
Number of input channels.
|
||||
out_channels (`int`, *optional*):
|
||||
Number of output channels. If None, defaults to `in_channels`.
|
||||
temb_channels (`int`, defaults to `512`):
|
||||
Number of time embedding channels.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
Dropout rate.
|
||||
num_layers (`int`, defaults to `1`):
|
||||
Number of resnet layers.
|
||||
resnet_eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
resnet_act_fn (`str`, defaults to `"swish"`):
|
||||
Activation function to use.
|
||||
resnet_groups (`int`, defaults to `32`):
|
||||
Number of groups to separate the channels into for group normalization.
|
||||
spatial_norm_dim (`int`, defaults to `16`):
|
||||
The dimension to use for spatial norm if it is to be used instead of group norm.
|
||||
add_upsample (`bool`, defaults to `True`):
|
||||
Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension.
|
||||
compress_time (`bool`, defaults to `False`):
|
||||
Whether or not to downsample across temporal dimension.
|
||||
pad_mode (str, defaults to `"first"`):
|
||||
Padding mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -522,12 +577,13 @@ class CogVideoXUpBlock3D(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
clear_fake_cp_cache: bool = False,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `CogVideoXUpBlock3D` class."""
|
||||
for resnet in self.resnets:
|
||||
@@ -540,10 +596,10 @@ class CogVideoXUpBlock3D(nn.Module):
|
||||
return create_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache
|
||||
create_custom_forward(resnet), hidden_states, temb, zq
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache)
|
||||
hidden_states = resnet(hidden_states, temb, zq)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -566,14 +622,12 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
options.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
The number of output channels for each block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
The number of layers per block.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups for normalization.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
double_z (`bool`, *optional*, defaults to `True`):
|
||||
Whether to double the number of output channels for the last block.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -651,11 +705,9 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True
|
||||
) -> torch.Tensor:
|
||||
def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
r"""The forward method of the `CogVideoXEncoder3D` class."""
|
||||
hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
hidden_states = self.conv_in(sample)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -668,25 +720,25 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
# 1. Down
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(down_block), hidden_states, temb, None, clear_fake_cp_cache
|
||||
create_custom_forward(down_block), hidden_states, temb, None
|
||||
)
|
||||
|
||||
# 2. Mid
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), hidden_states, temb, None, clear_fake_cp_cache
|
||||
create_custom_forward(self.mid_block), hidden_states, temb, None
|
||||
)
|
||||
else:
|
||||
# 1. Down
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states, temb, None, clear_fake_cp_cache)
|
||||
hidden_states = down_block(hidden_states, temb, None)
|
||||
|
||||
# 2. Mid
|
||||
hidden_states = self.mid_block(hidden_states, temb, None, clear_fake_cp_cache)
|
||||
hidden_states = self.mid_block(hidden_states, temb, None)
|
||||
|
||||
# 3. Post-process
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -704,14 +756,12 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
The number of output channels for each block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
The number of layers per block.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups for normalization.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
norm_type (`str`, *optional*, defaults to `"group"`):
|
||||
The normalization type to use. Can be either `"group"` or `"spatial"`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -788,7 +838,7 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
|
||||
self.up_blocks.append(up_block)
|
||||
|
||||
self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels)
|
||||
self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CogVideoXCausalConv3d(
|
||||
reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
|
||||
@@ -796,11 +846,9 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True
|
||||
) -> torch.Tensor:
|
||||
def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
r"""The forward method of the `CogVideoXDecoder3D` class."""
|
||||
hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
hidden_states = self.conv_in(sample)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -812,32 +860,33 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
|
||||
# 1. Mid
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), hidden_states, temb, sample, clear_fake_cp_cache
|
||||
create_custom_forward(self.mid_block), hidden_states, temb, sample
|
||||
)
|
||||
|
||||
# 2. Up
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block), hidden_states, temb, sample, clear_fake_cp_cache
|
||||
create_custom_forward(up_block), hidden_states, temb, sample
|
||||
)
|
||||
else:
|
||||
# 1. Mid
|
||||
hidden_states = self.mid_block(hidden_states, temb, sample, clear_fake_cp_cache)
|
||||
hidden_states = self.mid_block(hidden_states, temb, sample)
|
||||
|
||||
# 2. Up
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = up_block(hidden_states, temb, sample, clear_fake_cp_cache)
|
||||
hidden_states = up_block(hidden_states, temb, sample)
|
||||
|
||||
# 3. Post-process
|
||||
hidden_states = self.norm_out(hidden_states, sample)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encodfing images into latents and decoding latent representations into images.
|
||||
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
|
||||
[CogVideoX](https://github.com/THUDM/CogVideo).
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
@@ -853,7 +902,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
Tuple of block output channels.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
||||
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
||||
scaling_factor (`float`, *optional*, defaults to `1.15258426`):
|
||||
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
||||
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
||||
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
||||
@@ -864,9 +913,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
|
||||
can be fine-tuned / trained to a lower range without loosing too much precision in which case
|
||||
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
|
||||
mid_block_add_attention (`bool`, *optional*, default to `True`):
|
||||
If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
|
||||
mid_block will only have resnet blocks
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -896,7 +942,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
norm_eps: float = 1e-6,
|
||||
norm_num_groups: int = 32,
|
||||
temporal_compression_ratio: float = 4,
|
||||
sample_size: int = 256,
|
||||
sample_height: int = 480,
|
||||
sample_width: int = 720,
|
||||
scaling_factor: float = 1.15258426,
|
||||
shift_factor: Optional[float] = None,
|
||||
latents_mean: Optional[Tuple[float]] = None,
|
||||
@@ -904,7 +951,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
force_upcast: float = True,
|
||||
use_quant_conv: bool = False,
|
||||
use_post_quant_conv: bool = False,
|
||||
mid_block_add_attention: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -936,22 +982,108 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
self.tile_sample_min_size = self.config.sample_size
|
||||
sample_size = (
|
||||
self.config.sample_size[0]
|
||||
if isinstance(self.config.sample_size, (list, tuple))
|
||||
else self.config.sample_size
|
||||
# Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
|
||||
# recommended because the temporal parts of the VAE, here, are tricky to understand.
|
||||
# If you decode X latent frames together, the number of output frames is:
|
||||
# (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
|
||||
#
|
||||
# Example with num_latent_frames_batch_size = 2:
|
||||
# - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
|
||||
# => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
|
||||
# => 6 * 8 = 48 frames
|
||||
# - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
|
||||
# => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
|
||||
# ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
|
||||
# => 1 * 9 + 5 * 8 = 49 frames
|
||||
# It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
|
||||
# setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
|
||||
# number of temporal frames.
|
||||
self.num_latent_frames_batch_size = 2
|
||||
|
||||
# We make the minimum height and width of sample for tiling half that of the generally supported
|
||||
self.tile_sample_min_height = sample_height // 2
|
||||
self.tile_sample_min_width = sample_width // 2
|
||||
self.tile_latent_min_height = int(
|
||||
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
|
||||
)
|
||||
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
self.tile_overlap_factor = 0.25
|
||||
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
|
||||
# These are experimental overlap factors that were chosen based on experimentation and seem to work best for
|
||||
# 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
|
||||
# and so the tiling implementation has only been tested on those specific resolutions.
|
||||
self.tile_overlap_factor_height = 1 / 6
|
||||
self.tile_overlap_factor_width = 1 / 5
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def _clear_fake_context_parallel_cache(self):
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, CogVideoXCausalConv3d):
|
||||
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
|
||||
module._clear_fake_context_parallel_cache()
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
tile_sample_min_width: Optional[int] = None,
|
||||
tile_overlap_factor_height: Optional[float] = None,
|
||||
tile_overlap_factor_width: Optional[float] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
|
||||
Args:
|
||||
tile_sample_min_height (`int`, *optional*):
|
||||
The minimum height required for a sample to be separated into tiles across the height dimension.
|
||||
tile_sample_min_width (`int`, *optional*):
|
||||
The minimum width required for a sample to be separated into tiles across the width dimension.
|
||||
tile_overlap_factor_height (`int`, *optional*):
|
||||
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
|
||||
no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
|
||||
value might cause more tiles to be processed leading to slow down of the decoding process.
|
||||
tile_overlap_factor_width (`int`, *optional*):
|
||||
The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
|
||||
are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
|
||||
value might cause more tiles to be processed leading to slow down of the decoding process.
|
||||
"""
|
||||
self.use_tiling = True
|
||||
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
||||
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
||||
self.tile_latent_min_height = int(
|
||||
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
|
||||
)
|
||||
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
|
||||
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True, fake_cp: bool = False
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
@@ -960,14 +1092,12 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
fake_cp (`bool`, *optional*, defaults to `True`):
|
||||
If True, the fake context parallel will be used to reduce GPU memory consumption (Only 1 GPU work).
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded images. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
h = self.encoder(x, clear_fake_cp_cache=not fake_cp)
|
||||
h = self.encoder(x)
|
||||
if self.quant_conv is not None:
|
||||
h = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
@@ -975,10 +1105,34 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
frame_batch_size = self.num_latent_frames_batch_size
|
||||
dec = []
|
||||
for i in range(num_frames // frame_batch_size):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
||||
end_frame = frame_batch_size * (i + 1) + remaining_frames
|
||||
z_intermediate = z[:, :, start_frame:end_frame]
|
||||
if self.post_quant_conv is not None:
|
||||
z_intermediate = self.post_quant_conv(z_intermediate)
|
||||
z_intermediate = self.decoder(z_intermediate)
|
||||
dec.append(z_intermediate)
|
||||
|
||||
self._clear_fake_context_parallel_cache()
|
||||
dec = torch.cat(dec, dim=2)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, z: torch.FloatTensor, return_dict: bool = True, fake_cp: bool = False
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -986,20 +1140,116 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
fake_cp (`bool`, *optional*, defaults to `True`):
|
||||
If True, the fake context parallel will be used to reduce GPU memory consumption (Only 1 GPU work).
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
|
||||
"""
|
||||
if self.post_quant_conv is not None:
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z, clear_fake_cp_cache=not fake_cp)
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
|
||||
y / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
|
||||
x / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
# Rough memory assessment:
|
||||
# - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
|
||||
# - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
|
||||
# - Assume fp16 (2 bytes per value).
|
||||
# Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
|
||||
#
|
||||
# Memory assessment when using tiling:
|
||||
# - Assume everything as above but now HxW is 240x360 by tiling in half
|
||||
# Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
|
||||
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
|
||||
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
|
||||
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
|
||||
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
|
||||
row_limit_height = self.tile_sample_min_height - blend_extent_height
|
||||
row_limit_width = self.tile_sample_min_width - blend_extent_width
|
||||
frame_batch_size = self.num_latent_frames_batch_size
|
||||
|
||||
# Split z into overlapping tiles and decode them separately.
|
||||
# The tiles have an overlap to avoid seams between tiles.
|
||||
rows = []
|
||||
for i in range(0, height, overlap_height):
|
||||
row = []
|
||||
for j in range(0, width, overlap_width):
|
||||
time = []
|
||||
for k in range(num_frames // frame_batch_size):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
|
||||
end_frame = frame_batch_size * (k + 1) + remaining_frames
|
||||
tile = z[
|
||||
:,
|
||||
:,
|
||||
start_frame:end_frame,
|
||||
i : i + self.tile_latent_min_height,
|
||||
j : j + self.tile_latent_min_width,
|
||||
]
|
||||
if self.post_quant_conv is not None:
|
||||
tile = self.post_quant_conv(tile)
|
||||
tile = self.decoder(tile)
|
||||
time.append(tile)
|
||||
self._clear_fake_context_parallel_cache()
|
||||
row.append(torch.cat(time, dim=2))
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
|
||||
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
|
||||
result_rows.append(torch.cat(result_row, dim=4))
|
||||
|
||||
dec = torch.cat(result_rows, dim=3)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -20,6 +20,7 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalModelMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
@@ -92,7 +93,7 @@ class SparseControlNetConditioningEmbedding(nn.Module):
|
||||
return embedding
|
||||
|
||||
|
||||
class SparseControlNetModel(ModelMixin, ConfigMixin):
|
||||
class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
|
||||
Models](https://arxiv.org/abs/2311.16933).
|
||||
@@ -314,6 +315,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin):
|
||||
temporal_num_attention_heads=motion_num_attention_heads[i],
|
||||
temporal_max_seq_length=motion_max_seq_length,
|
||||
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
|
||||
temporal_double_self_attention=False,
|
||||
)
|
||||
elif down_block_type == "DownBlockMotion":
|
||||
down_block = DownBlockMotion(
|
||||
@@ -331,6 +333,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin):
|
||||
temporal_num_attention_heads=motion_num_attention_heads[i],
|
||||
temporal_max_seq_length=motion_max_seq_length,
|
||||
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
|
||||
temporal_double_self_attention=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
||||
@@ -285,7 +285,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
upcast_attention (`bool`, defaults to `True`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
max_norm_num_groups (`int`, defaults to 32):
|
||||
Maximum number of groups in group normal. The actual number will the the largest divisor of the respective
|
||||
Maximum number of groups in group normal. The actual number will be the largest divisor of the respective
|
||||
channels, that is <= max_norm_num_groups.
|
||||
"""
|
||||
|
||||
|
||||
@@ -374,6 +374,90 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
return embeds
|
||||
|
||||
|
||||
def get_3d_rotary_pos_embed(
|
||||
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
RoPE for video tokens with 3D structure.
|
||||
|
||||
Args:
|
||||
embed_dim: (`int`):
|
||||
The embedding dimension size, corresponding to hidden_size_head.
|
||||
crops_coords (`Tuple[int]`):
|
||||
The top-left and bottom-right coordinates of the crop.
|
||||
grid_size (`Tuple[int]`):
|
||||
The grid size of the spatial positional embedding (height, width).
|
||||
temporal_size (`int`):
|
||||
The size of the temporal dimension.
|
||||
theta (`float`):
|
||||
Scaling factor for frequency computation.
|
||||
use_real (`bool`):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
|
||||
"""
|
||||
start, stop = crops_coords
|
||||
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
|
||||
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
||||
|
||||
# Compute dimensions for each axis
|
||||
dim_t = embed_dim // 4
|
||||
dim_h = embed_dim // 8 * 3
|
||||
dim_w = embed_dim // 8 * 3
|
||||
|
||||
# Temporal frequencies
|
||||
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
|
||||
grid_t = torch.from_numpy(grid_t).float()
|
||||
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
|
||||
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
|
||||
|
||||
# Spatial frequencies for height and width
|
||||
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
|
||||
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
|
||||
grid_h = torch.from_numpy(grid_h).float()
|
||||
grid_w = torch.from_numpy(grid_w).float()
|
||||
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
|
||||
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
|
||||
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
|
||||
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
|
||||
|
||||
# Broadcast and concatenate tensors along specified dimension
|
||||
def broadcast(tensors, dim=-1):
|
||||
num_tensors = len(tensors)
|
||||
shape_lens = {len(t.shape) for t in tensors}
|
||||
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
||||
shape_len = list(shape_lens)[0]
|
||||
dim = (dim + shape_len) if dim < 0 else dim
|
||||
dims = list(zip(*(list(t.shape) for t in tensors)))
|
||||
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
||||
assert all(
|
||||
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
|
||||
), "invalid dimensions for broadcastable concatenation"
|
||||
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
||||
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
||||
expanded_dims.insert(dim, (dim, dims[dim]))
|
||||
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
||||
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
||||
return torch.cat(tensors, dim=dim)
|
||||
|
||||
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||
|
||||
t, h, w, d = freqs.shape
|
||||
freqs = freqs.view(t * h * w, d)
|
||||
|
||||
# Generate sine and cosine components
|
||||
sin = freqs.sin()
|
||||
cos = freqs.cos()
|
||||
|
||||
if use_real:
|
||||
return cos, sin
|
||||
else:
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
|
||||
"""
|
||||
RoPE for image tokens with 2d structure.
|
||||
|
||||
@@ -773,7 +773,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
try:
|
||||
accelerate.load_checkpoint_and_dispatch(
|
||||
model,
|
||||
model_file if not is_sharded else sharded_ckpt_cached_folder,
|
||||
model_file if not is_sharded else index_file,
|
||||
device_map,
|
||||
max_memory=max_memory,
|
||||
offload_folder=offload_folder,
|
||||
@@ -803,7 +803,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
model._temp_convert_self_to_deprecated_attention_blocks()
|
||||
accelerate.load_checkpoint_and_dispatch(
|
||||
model,
|
||||
model_file if not is_sharded else sharded_ckpt_cached_folder,
|
||||
model_file if not is_sharded else index_file,
|
||||
device_map,
|
||||
max_memory=max_memory,
|
||||
offload_folder=offload_folder,
|
||||
|
||||
@@ -34,7 +34,11 @@ class AdaLayerNorm(nn.Module):
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
num_embeddings (`int`): The size of the embeddings dictionary.
|
||||
num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
|
||||
output_dim (`int`, *optional*):
|
||||
norm_elementwise_affine (`bool`, defaults to `False):
|
||||
norm_eps (`bool`, defaults to `False`):
|
||||
chunk_dim (`int`, defaults to `0`):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -49,14 +53,13 @@ class AdaLayerNorm(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.chunk_dim = chunk_dim
|
||||
output_dim = output_dim or embedding_dim * 2
|
||||
|
||||
if num_embeddings is not None:
|
||||
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
||||
else:
|
||||
self.emb = None
|
||||
|
||||
output_dim = output_dim or embedding_dim * 2
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, output_dim)
|
||||
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
||||
@@ -68,7 +71,10 @@ class AdaLayerNorm(nn.Module):
|
||||
temb = self.emb(timestep)
|
||||
|
||||
temb = self.linear(self.silu(temb))
|
||||
|
||||
if self.chunk_dim == 1:
|
||||
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
|
||||
# other if-branch. This branch is specific to CogVideoX for now.
|
||||
shift, scale = temb.chunk(2, dim=1)
|
||||
shift = shift[:, None, :]
|
||||
scale = scale[:, None, :]
|
||||
|
||||
@@ -22,7 +22,12 @@ import torch.nn.functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torch_version, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention_processor import Attention, AuraFlowAttnProcessor2_0
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
AuraFlowAttnProcessor2_0,
|
||||
FusedAuraFlowAttnProcessor2_0,
|
||||
)
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -63,6 +68,21 @@ class AuraFlowPatchEmbed(nn.Module):
|
||||
self.height, self.width = height // patch_size, width // patch_size
|
||||
self.base_size = height // patch_size
|
||||
|
||||
def pe_selection_index_based_on_dim(self, h, w):
|
||||
# select subset of positional embedding based on H, W, where H, W is size of latent
|
||||
# PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
|
||||
# because original input are in flattened format, we have to flatten this 2d grid as well.
|
||||
h_p, w_p = h // self.patch_size, w // self.patch_size
|
||||
original_pe_indexes = torch.arange(self.pos_embed.shape[1])
|
||||
h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
|
||||
original_pe_indexes = original_pe_indexes.view(h_max, w_max)
|
||||
starth = h_max // 2 - h_p // 2
|
||||
endh = starth + h_p
|
||||
startw = w_max // 2 - w_p // 2
|
||||
endw = startw + w_p
|
||||
original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
|
||||
return original_pe_indexes.flatten()
|
||||
|
||||
def forward(self, latent):
|
||||
batch_size, num_channels, height, width = latent.size()
|
||||
latent = latent.view(
|
||||
@@ -75,7 +95,8 @@ class AuraFlowPatchEmbed(nn.Module):
|
||||
)
|
||||
latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
|
||||
latent = self.proj(latent)
|
||||
return latent + self.pos_embed
|
||||
pe_index = self.pe_selection_index_based_on_dim(height, width)
|
||||
return latent + self.pos_embed[:, pe_index]
|
||||
|
||||
|
||||
# Taken from the original Aura flow inference code.
|
||||
@@ -320,6 +341,106 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedAuraFlowAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -22,6 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torch_version, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention, FeedForward
|
||||
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
|
||||
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -34,37 +35,37 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@maybe_allow_in_graph
|
||||
class CogVideoXBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block used in CogVideoX model. TODO: add link to CogVideoX upon release
|
||||
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
only_cross_attention (`bool`, *optional*):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
upcast_attention (`bool`, *optional*):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
dim (`int`):
|
||||
The number of channels in the input and output.
|
||||
num_attention_heads (`int`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`):
|
||||
The number of channels in each head.
|
||||
time_embed_dim (`int`):
|
||||
The number of channels in timestep embedding.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to be used in feed-forward.
|
||||
attention_bias (`bool`, defaults to `False`):
|
||||
Whether or not to use bias in attention projection layers.
|
||||
qk_norm (`bool`, defaults to `True`):
|
||||
Whether or not to use normalization after query and key projections in Attention.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
||||
final_dropout (`bool` *optional*, defaults to False):
|
||||
norm_eps (`float`, defaults to `1e-5`):
|
||||
Epsilon value for normalization layers.
|
||||
final_dropout (`bool` defaults to `False`):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
attention_type (`str`, *optional*, defaults to `"default"`):
|
||||
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
||||
positional_embeddings (`str`, *optional*, defaults to `None`):
|
||||
The type of positional embeddings to apply to.
|
||||
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
||||
The maximum number of positional embeddings to apply.
|
||||
ff_inner_dim (`int`, *optional*, defaults to `None`):
|
||||
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
|
||||
ff_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in Feed-forward layer.
|
||||
attention_out_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in Attention output projection layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -97,6 +98,7 @@ class CogVideoXBlock(nn.Module):
|
||||
eps=1e-6,
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
processor=CogVideoXAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 2. Feed Forward
|
||||
@@ -116,24 +118,24 @@ class CogVideoXBlock(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
|
||||
# attention
|
||||
text_length = norm_encoder_hidden_states.size(1)
|
||||
|
||||
# CogVideoX uses concatenated text + video embeddings with self-attention instead of using
|
||||
# them in cross-attention individually
|
||||
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
attn_output = self.attn1(
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + gate_msa * attn_output[:, text_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length]
|
||||
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
||||
@@ -144,46 +146,64 @@ class CogVideoXBlock(nn.Module):
|
||||
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length]
|
||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A Transformer model for video-like data in CogVideoX. TODO: add link to CogVideoX upon release
|
||||
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
num_attention_heads (`int`, defaults to `30`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `64`):
|
||||
The number of channels in each head.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*):
|
||||
out_channels (`int`, *optional*, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
||||
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||
This is fixed during training since it is used to learn a number of position embeddings.
|
||||
patch_size (`int`, *optional*):
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
time_embed_dim (`int`, defaults to `512`):
|
||||
Output dimension of timestep embeddings.
|
||||
text_embed_dim (`int`, defaults to `4096`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
num_layers (`int`, defaults to `30`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
attention_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in the attention projection layers.
|
||||
sample_width (`int`, defaults to `90`):
|
||||
The width of the input latents.
|
||||
sample_height (`int`, defaults to `60`):
|
||||
The height of the input latents.
|
||||
sample_frames (`int`, defaults to `49`):
|
||||
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
|
||||
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
|
||||
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
|
||||
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
|
||||
patch_size (`int`, defaults to `2`):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*):
|
||||
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
||||
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
||||
added to the hidden states. During inference, you can denoise for up to but not more steps than
|
||||
`num_embeds_ada_norm`.
|
||||
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
temporal_compression_ratio (`int`, defaults to `4`):
|
||||
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
|
||||
max_text_seq_length (`int`, defaults to `226`):
|
||||
The maximum sequence length of the input text embeddings.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to use in feed-forward.
|
||||
timestep_activation_fn (`str`, defaults to `"silu"`):
|
||||
Activation function to use when generating the timestep embeddings.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
Whether or not to use elementwise affine in normalization layers.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
|
||||
caption_channels (`int`, *optional*):
|
||||
The number of channels in the caption embeddings.
|
||||
video_length (`int`, *optional*):
|
||||
The number of frames in the video-like data.
|
||||
norm_eps (`float`, defaults to `1e-5`):
|
||||
The epsilon value to use in normalization layers.
|
||||
spatial_interpolation_scale (`float`, defaults to `1.875`):
|
||||
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
|
||||
temporal_interpolation_scale (`float`, defaults to `1.0`):
|
||||
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -193,7 +213,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
self,
|
||||
num_attention_heads: int = 30,
|
||||
attention_head_dim: int = 64,
|
||||
in_channels: Optional[int] = 16,
|
||||
in_channels: int = 16,
|
||||
out_channels: Optional[int] = 16,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
@@ -214,6 +234,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
norm_eps: float = 1e-5,
|
||||
spatial_interpolation_scale: float = 1.875,
|
||||
temporal_interpolation_scale: float = 1.0,
|
||||
use_rotary_positional_embeddings: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
@@ -278,12 +299,113 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: Union[int, float, torch.LongTensor],
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
batch_size, num_frames, channels, height, width = hidden_states.shape
|
||||
@@ -302,16 +424,18 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
||||
|
||||
# 3. Position embedding
|
||||
seq_length = height * width * num_frames // (self.config.patch_size**2)
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
if not self.config.use_rotary_positional_embeddings:
|
||||
seq_length = height * width * num_frames // (self.config.patch_size**2)
|
||||
|
||||
pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length]
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
hidden_states = self.embedding_dropout(hidden_states)
|
||||
pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
hidden_states = self.embedding_dropout(hidden_states)
|
||||
|
||||
encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
|
||||
hidden_states = hidden_states[:, self.config.max_text_seq_length :]
|
||||
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
||||
hidden_states = hidden_states[:, text_seq_length:]
|
||||
|
||||
# 5. Transformer blocks
|
||||
# 4. Transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -327,6 +451,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
emb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
@@ -334,15 +459,23 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=emb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
if not self.config.use_rotary_positional_embeddings:
|
||||
# CogVideoX-2B
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
else:
|
||||
# CogVideoX-5B
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
hidden_states = hidden_states[:, text_seq_length:]
|
||||
|
||||
# 6. Final block
|
||||
# 5. Final block
|
||||
hidden_states = self.norm_out(hidden_states, temb=emb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 7. Unpatchify
|
||||
# 6. Unpatchify
|
||||
p = self.config.patch_size
|
||||
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
|
||||
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
||||
|
||||
@@ -19,7 +19,7 @@ from torch import nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torch_version, logging
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..attention_processor import AttentionProcessor
|
||||
from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -247,6 +247,46 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
@@ -125,6 +125,8 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
gate = gate.unsqueeze(1)
|
||||
hidden_states = gate * self.proj_out(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -223,11 +225,13 @@ class FluxTransformerBlock(nn.Module):
|
||||
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
if encoder_hidden_states.dtype == torch.float16:
|
||||
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
The Transformer model introduced in Flux.
|
||||
|
||||
|
||||
@@ -233,6 +233,7 @@ class DownBlockMotion(nn.Module):
|
||||
temporal_cross_attention_dim: Optional[int] = None,
|
||||
temporal_max_seq_length: int = 32,
|
||||
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
temporal_double_self_attention: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -282,6 +283,7 @@ class DownBlockMotion(nn.Module):
|
||||
positional_embeddings="sinusoidal",
|
||||
num_positional_embeddings=temporal_max_seq_length,
|
||||
attention_head_dim=out_channels // temporal_num_attention_heads[i],
|
||||
double_self_attention=temporal_double_self_attention,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -343,6 +345,7 @@ class DownBlockMotion(nn.Module):
|
||||
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
@@ -384,6 +387,7 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
temporal_num_attention_heads: int = 8,
|
||||
temporal_max_seq_length: int = 32,
|
||||
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
temporal_double_self_attention: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -465,6 +469,7 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
positional_embeddings="sinusoidal",
|
||||
num_positional_embeddings=temporal_max_seq_length,
|
||||
attention_head_dim=out_channels // temporal_num_attention_heads,
|
||||
double_self_attention=temporal_double_self_attention,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -536,6 +541,7 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -761,6 +767,7 @@ class CrossAttnUpBlockMotion(nn.Module):
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -921,9 +928,9 @@ class UpBlockMotion(nn.Module):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
hidden_states = motion_module(hidden_states, num_frames=num_frames)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
@@ -1923,7 +1930,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
||||
"""
|
||||
Sets the attention processor to use [feed forward
|
||||
@@ -1953,7 +1959,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, chunk_size, dim)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
|
||||
def disable_forward_chunking(self) -> None:
|
||||
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||
if hasattr(module, "set_chunk_feed_forward"):
|
||||
|
||||
@@ -10,6 +10,7 @@ from ..utils import (
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_sentencepiece_available,
|
||||
is_torch_available,
|
||||
is_torch_npu_available,
|
||||
is_transformers_available,
|
||||
@@ -146,7 +147,9 @@ else:
|
||||
_import_structure["pag"].extend(
|
||||
[
|
||||
"AnimateDiffPAGPipeline",
|
||||
"KolorsPAGPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
"StableDiffusion3PAGPipeline",
|
||||
"StableDiffusionPAGPipeline",
|
||||
"StableDiffusionControlNetPAGPipeline",
|
||||
"StableDiffusionXLPAGPipeline",
|
||||
@@ -206,12 +209,6 @@ else:
|
||||
"Kandinsky3Img2ImgPipeline",
|
||||
"Kandinsky3Pipeline",
|
||||
]
|
||||
_import_structure["kolors"] = [
|
||||
"KolorsPipeline",
|
||||
"KolorsImg2ImgPipeline",
|
||||
"ChatGLMModel",
|
||||
"ChatGLMTokenizer",
|
||||
]
|
||||
_import_structure["latent_consistency_models"] = [
|
||||
"LatentConsistencyModelImg2ImgPipeline",
|
||||
"LatentConsistencyModelPipeline",
|
||||
@@ -351,6 +348,22 @@ else:
|
||||
"StableDiffusionKDiffusionPipeline",
|
||||
"StableDiffusionXLKDiffusionPipeline",
|
||||
]
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import (
|
||||
dummy_torch_and_transformers_and_sentencepiece_objects,
|
||||
)
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_sentencepiece_objects))
|
||||
else:
|
||||
_import_structure["kolors"] = [
|
||||
"KolorsPipeline",
|
||||
"KolorsImg2ImgPipeline",
|
||||
]
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -509,12 +522,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Kandinsky3Img2ImgPipeline,
|
||||
Kandinsky3Pipeline,
|
||||
)
|
||||
from .kolors import (
|
||||
ChatGLMModel,
|
||||
ChatGLMTokenizer,
|
||||
KolorsImg2ImgPipeline,
|
||||
KolorsPipeline,
|
||||
)
|
||||
from .latent_consistency_models import (
|
||||
LatentConsistencyModelImg2ImgPipeline,
|
||||
LatentConsistencyModelPipeline,
|
||||
@@ -536,7 +543,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pag import (
|
||||
AnimateDiffPAGPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
KolorsPAGPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionPAGPipeline,
|
||||
StableDiffusionXLControlNetPAGPipeline,
|
||||
@@ -644,6 +653,17 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLKDiffusionPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_sentencepiece_objects import *
|
||||
else:
|
||||
from .kolors import (
|
||||
KolorsImg2ImgPipeline,
|
||||
KolorsPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -42,6 +42,7 @@ from ...utils import (
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..free_init_utils import FreeInitMixin
|
||||
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
@@ -72,6 +73,7 @@ class AnimateDiffPipeline(
|
||||
IPAdapterMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
@@ -394,15 +396,20 @@ class AnimateDiffPipeline(
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
def decode_latents(self, latents, decode_chunk_size: int = 16):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
batch_size, channels, num_frames, height, width = latents.shape
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
image = self.vae.decode(latents).sample
|
||||
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
video = []
|
||||
for i in range(0, latents.shape[0], decode_chunk_size):
|
||||
batch_latents = latents[i : i + decode_chunk_size]
|
||||
batch_latents = self.vae.decode(batch_latents).sample
|
||||
video.append(batch_latents)
|
||||
|
||||
video = torch.cat(video)
|
||||
video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
video = video.float()
|
||||
return video
|
||||
@@ -495,10 +502,21 @@ class AnimateDiffPipeline(
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
# If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
|
||||
if self.free_noise_enabled:
|
||||
latents = self._prepare_latents_free_noise(
|
||||
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
@@ -506,11 +524,6 @@ class AnimateDiffPipeline(
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
@@ -569,6 +582,7 @@ class AnimateDiffPipeline(
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
decode_chunk_size: int = 16,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -637,6 +651,8 @@ class AnimateDiffPipeline(
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
decode_chunk_size (`int`, defaults to `16`):
|
||||
The number of frames to decode at a time when calling `decode_latents` method.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -808,7 +824,7 @@ class AnimateDiffPipeline(
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video_tensor = self.decode_latents(latents, decode_chunk_size)
|
||||
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
|
||||
|
||||
# 10. Offload all models
|
||||
|
||||
@@ -30,6 +30,7 @@ from ...utils.torch_utils import is_compiled_module, randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..controlnet.multicontrolnet import MultiControlNetModel
|
||||
from ..free_init_utils import FreeInitMixin
|
||||
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
@@ -109,6 +110,7 @@ class AnimateDiffControlNetPipeline(
|
||||
IPAdapterMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-video generation with ControlNet guidance.
|
||||
@@ -432,15 +434,16 @@ class AnimateDiffControlNetPipeline(
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
def decode_latents(self, latents, decode_batch_size: int = 16):
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
|
||||
def decode_latents(self, latents, decode_chunk_size: int = 16):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
batch_size, channels, num_frames, height, width = latents.shape
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
video = []
|
||||
for i in range(0, latents.shape[0], decode_batch_size):
|
||||
batch_latents = latents[i : i + decode_batch_size]
|
||||
for i in range(0, latents.shape[0], decode_chunk_size):
|
||||
batch_latents = latents[i : i + decode_chunk_size]
|
||||
batch_latents = self.vae.decode(batch_latents).sample
|
||||
video.append(batch_latents)
|
||||
|
||||
@@ -608,10 +611,22 @@ class AnimateDiffControlNetPipeline(
|
||||
if end > 1.0:
|
||||
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
# If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
|
||||
if self.free_noise_enabled:
|
||||
latents = self._prepare_latents_free_noise(
|
||||
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
@@ -619,11 +634,6 @@ class AnimateDiffControlNetPipeline(
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
@@ -718,7 +728,7 @@ class AnimateDiffControlNetPipeline(
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
decode_batch_size: int = 16,
|
||||
decode_chunk_size: int = 16,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
@@ -1054,7 +1064,7 @@ class AnimateDiffControlNetPipeline(
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents, decode_batch_size)
|
||||
video_tensor = self.decode_latents(latents, decode_chunk_size)
|
||||
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
|
||||
|
||||
# 10. Offload all models
|
||||
|
||||
@@ -35,6 +35,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..free_init_utils import FreeInitMixin
|
||||
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
@@ -176,6 +177,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
IPAdapterMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for video-to-video generation.
|
||||
@@ -498,15 +500,29 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
def encode_video(self, video, generator, decode_chunk_size: int = 16) -> torch.Tensor:
|
||||
latents = []
|
||||
for i in range(0, len(video), decode_chunk_size):
|
||||
batch_video = video[i : i + decode_chunk_size]
|
||||
batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator)
|
||||
latents.append(batch_video)
|
||||
return torch.cat(latents)
|
||||
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
|
||||
def decode_latents(self, latents, decode_chunk_size: int = 16):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
batch_size, channels, num_frames, height, width = latents.shape
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
image = self.vae.decode(latents).sample
|
||||
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
video = []
|
||||
for i in range(0, latents.shape[0], decode_chunk_size):
|
||||
batch_latents = latents[i : i + decode_chunk_size]
|
||||
batch_latents = self.vae.decode(batch_latents).sample
|
||||
video.append(batch_latents)
|
||||
|
||||
video = torch.cat(video)
|
||||
video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
video = video.float()
|
||||
return video
|
||||
@@ -622,6 +638,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
decode_chunk_size: int = 16,
|
||||
):
|
||||
if latents is None:
|
||||
num_frames = video.shape[1]
|
||||
@@ -656,13 +673,11 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
)
|
||||
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(video[i]), generator=generator[i]).unsqueeze(0)
|
||||
self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(vid), generator=generator).unsqueeze(0) for vid in video
|
||||
]
|
||||
init_latents = [self.encode_video(vid, generator, decode_chunk_size).unsqueeze(0) for vid in video]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
|
||||
@@ -747,6 +762,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
decode_chunk_size: int = 16,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
@@ -822,6 +838,8 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
decode_chunk_size (`int`, defaults to `16`):
|
||||
The number of frames to decode at a time when calling `decode_latents` method.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -923,6 +941,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
decode_chunk_size=decode_chunk_size,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
@@ -990,7 +1009,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video_tensor = self.decode_latents(latents, decode_chunk_size)
|
||||
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
|
||||
|
||||
# 10. Offload all models
|
||||
|
||||
@@ -18,6 +18,7 @@ from collections import OrderedDict
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..utils import is_sentencepiece_available
|
||||
from .aura_flow import AuraFlowPipeline
|
||||
from .controlnet import (
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
@@ -47,11 +48,11 @@ from .kandinsky2_2 import (
|
||||
KandinskyV22Pipeline,
|
||||
)
|
||||
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
|
||||
from .kolors import KolorsImg2ImgPipeline, KolorsPipeline
|
||||
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
|
||||
from .pag import (
|
||||
HunyuanDiTPAGPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionPAGPipeline,
|
||||
StableDiffusionXLControlNetPAGPipeline,
|
||||
@@ -84,6 +85,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion", StableDiffusionPipeline),
|
||||
("stable-diffusion-xl", StableDiffusionXLPipeline),
|
||||
("stable-diffusion-3", StableDiffusion3Pipeline),
|
||||
("stable-diffusion-3-pag", StableDiffusion3PAGPipeline),
|
||||
("if", IFPipeline),
|
||||
("hunyuan", HunyuanDiTPipeline),
|
||||
("hunyuan-pag", HunyuanDiTPAGPipeline),
|
||||
@@ -103,7 +105,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGPipeline),
|
||||
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
|
||||
("auraflow", AuraFlowPipeline),
|
||||
("kolors", KolorsPipeline),
|
||||
("flux", FluxPipeline),
|
||||
]
|
||||
)
|
||||
@@ -121,7 +122,6 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
|
||||
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
|
||||
("lcm", LatentConsistencyModelImg2ImgPipeline),
|
||||
("kolors", KolorsImg2ImgPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -160,6 +160,14 @@ _AUTO_INPAINT_DECODER_PIPELINES_MAPPING = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
if is_sentencepiece_available():
|
||||
from .kolors import KolorsPipeline
|
||||
from .pag import KolorsPAGPipeline
|
||||
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors-pag"] = KolorsPAGPipeline
|
||||
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline
|
||||
|
||||
SUPPORTED_TASKS_MAPPINGS = [
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
|
||||
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
|
||||
|
||||
@@ -23,6 +23,7 @@ from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
||||
from ...models.embeddings import get_3d_rotary_pos_embed
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
||||
from ...utils import BaseOutput, logging, replace_example_docstring
|
||||
@@ -36,10 +37,12 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CogVideoXPipeline
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda")
|
||||
>>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
|
||||
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
|
||||
>>> prompt = (
|
||||
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
|
||||
... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
|
||||
@@ -48,14 +51,31 @@ EXAMPLE_DOC_STRING = """
|
||||
... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
|
||||
... "atmosphere of this unique musical performance."
|
||||
... )
|
||||
>>> video = pipe(
|
||||
... "a polar bear dancing, high quality, realistic", guidance_scale=6, num_inference_steps=20
|
||||
... ).frames[0]
|
||||
>>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
||||
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
h, w = src
|
||||
r = h / w
|
||||
if r > (th / tw):
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h))
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
@@ -154,7 +174,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
||||
"""
|
||||
|
||||
_optional_components = ["tokenizer", "text_encoder"]
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
|
||||
_callback_tensor_inputs = [
|
||||
@@ -182,9 +202,6 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
self.vae_scale_factor_temporal = (
|
||||
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
|
||||
)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 226
|
||||
)
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
@@ -214,7 +231,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
@@ -336,20 +353,11 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def decode_latents(self, latents: torch.Tensor, num_seconds: int):
|
||||
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
frames = []
|
||||
for i in range(num_seconds):
|
||||
# Whether or not to clear fake context parallel cache
|
||||
fake_cp = i + 1 < num_seconds
|
||||
start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
|
||||
|
||||
current_frames = self.vae.decode(latents[:, :, start_frame:end_frame], fake_cp=fake_cp).sample
|
||||
frames.append(current_frames)
|
||||
|
||||
frames = torch.cat(frames, dim=2)
|
||||
frames = self.vae.decode(latents).sample
|
||||
return frames
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
@@ -422,6 +430,46 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def fuse_qkv_projections(self) -> None:
|
||||
r"""Enables fused QKV projections."""
|
||||
self.fusing_transformer = True
|
||||
self.transformer.fuse_qkv_projections()
|
||||
|
||||
def unfuse_qkv_projections(self) -> None:
|
||||
r"""Disable QKV projection fusion if enabled."""
|
||||
if not self.fusing_transformer:
|
||||
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.transformer.unfuse_qkv_projections()
|
||||
self.fusing_transformer = False
|
||||
|
||||
def _prepare_rotary_positional_embeddings(
|
||||
self,
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
|
||||
grid_crops_coords = get_resize_crop_region_for_grid(
|
||||
(grid_height, grid_width), base_size_width, base_size_height
|
||||
)
|
||||
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
||||
embed_dim=self.transformer.config.attention_head_dim,
|
||||
crops_coords=grid_crops_coords,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=num_frames,
|
||||
use_real=True,
|
||||
)
|
||||
|
||||
freqs_cos = freqs_cos.to(device=device)
|
||||
freqs_sin = freqs_sin.to(device=device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
@@ -442,11 +490,11 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 480,
|
||||
width: int = 720,
|
||||
num_frames: int = 48,
|
||||
fps: int = 8,
|
||||
num_frames: int = 49,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
guidance_scale: float = 6,
|
||||
use_dynamic_cfg: bool = False,
|
||||
num_videos_per_prompt: int = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
@@ -459,7 +507,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
use_dynamic_cfg: bool = False,
|
||||
max_sequence_length: int = 226,
|
||||
) -> Union[CogVideoXPipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -525,6 +573,9 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `226`):
|
||||
Maximum sequence length in encoded prompt. Must be consistent with
|
||||
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -534,9 +585,10 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
assert (
|
||||
num_frames <= 48 and num_frames % fps == 0 and fps == 8
|
||||
), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX."
|
||||
if num_frames > 49:
|
||||
raise ValueError(
|
||||
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
|
||||
)
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
@@ -581,6 +633,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
@@ -592,7 +645,6 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
num_frames += 1
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
latent_channels,
|
||||
@@ -608,7 +660,14 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Denoising loop
|
||||
# 7. Create rotary embeds if required
|
||||
image_rotary_emb = (
|
||||
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
||||
if self.transformer.config.use_rotary_positional_embeddings
|
||||
else None
|
||||
)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
@@ -629,6 +688,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
@@ -671,8 +731,8 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if not output_type == "latents":
|
||||
video = self.decode_latents(latents, num_frames // fps)
|
||||
if not output_type == "latent":
|
||||
video = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
@@ -76,13 +76,13 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> import numpy as np
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> from transformers import DPTFeatureExtractor, DPTForDepthEstimation
|
||||
>>> from transformers import DPTImageProcessor, DPTForDepthEstimation
|
||||
>>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
|
||||
>>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
|
||||
>>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
>>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
>>> controlnet = ControlNetModel.from_pretrained(
|
||||
... "diffusers/controlnet-depth-sdxl-1.0-small",
|
||||
... variant="fp16",
|
||||
|
||||
@@ -23,7 +23,7 @@ from flax.core.frozen_dict import FrozenDict
|
||||
from flax.jax_utils import unreplicate
|
||||
from flax.training.common_utils import shard
|
||||
from PIL import Image
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
|
||||
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel
|
||||
|
||||
from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel
|
||||
from ...schedulers import (
|
||||
@@ -149,7 +149,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
|
||||
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
|
||||
],
|
||||
safety_checker: FlaxStableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user