Compare commits

..

3 Commits

Author SHA1 Message Date
Dhruv Nair a3584b7ad0 update 2024-08-06 12:05:26 +00:00
Dhruv Nair 769e40eb06 update 2024-07-30 06:33:31 +00:00
Dhruv Nair 70fc9394de update 2024-07-30 06:22:34 +00:00
132 changed files with 1916 additions and 19021 deletions
+3 -3
View File
@@ -13,13 +13,13 @@ env:
jobs:
torch_pipelines_cuda_benchmark_tests:
env:
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_BENCHMARK }}
name: Torch Core Pipelines CUDA Benchmarking Tests
strategy:
fail-fast: false
max-parallel: 1
runs-on:
runs-on:
group: aws-g6-4xlarge-plus
container:
image: diffusers/diffusers-pytorch-compile-cuda
@@ -59,7 +59,7 @@ jobs:
if: ${{ success() }}
run: |
pip install requests && python utils/notify_benchmarking_status.py --status=success
- name: Report failure status
if: ${{ failure() }}
run: |
@@ -24,7 +24,7 @@ jobs:
mirror_community_pipeline:
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_COMMUNITY_MIRROR }}
runs-on: ubuntu-latest
steps:
# Checkout to correct ref
@@ -95,7 +95,7 @@ jobs:
if: ${{ success() }}
run: |
pip install requests && python utils/notify_community_pipelines_mirror.py --status=success
- name: Report failure status
if: ${{ failure() }}
run: |
+1 -1
View File
@@ -32,7 +32,7 @@ jobs:
fetch-depth: 2
- name: Install dependencies
run: |
pip install -e .[test]
pip install -e .
pip install huggingface_hub
- name: Fetch Pipeline Matrix
id: fetch_pipeline_matrix
+1 -1
View File
@@ -63,7 +63,7 @@ In the same spirit, you are of immense help to the community by answering such q
**Please** keep in mind that the more effort you put into asking or answering a question, the higher
the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database.
In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formatted/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formated/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
**NOTE about channels**:
[*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that we posted some time ago.
+2 -2
View File
@@ -67,7 +67,7 @@ Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggi
## Quickstart
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 30,000+ checkpoints):
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 27.000+ checkpoints):
```python
from diffusers import DiffusionPipeline
@@ -209,7 +209,7 @@ Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz9
- https://github.com/deep-floyd/IF
- https://github.com/bentoml/BentoML
- https://github.com/bmaltais/kohya_ss
- +14,000 other amazing GitHub repositories 💪
- +12.000 other amazing GitHub repositories 💪
Thank you for using us ❤️.
+4 -12
View File
@@ -190,6 +190,10 @@
- local: conceptual/evaluation
title: Evaluating Diffusion Models
title: Conceptual Guides
- sections:
- local: community_projects
title: Projects built with Diffusers
title: Community Projects
- sections:
- isExpanded: false
sections:
@@ -239,8 +243,6 @@
title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_tiny
title: Tiny AutoEncoder
- local: api/models/autoencoder_oobleck
title: Oobleck AutoEncoder
- local: api/models/consistency_decoder_vae
title: ConsistencyDecoderVAE
- local: api/models/transformer2d
@@ -253,8 +255,6 @@
title: HunyuanDiT2DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
- local: api/models/flux_transformer
title: FluxTransformer2DModel
- local: api/models/latte_transformer3d
title: LatteTransformer3DModel
- local: api/models/lumina_nextdit2d
@@ -263,8 +263,6 @@
title: TransformerTemporalModel
- local: api/models/sd3_transformer2d
title: SD3Transformer2DModel
- local: api/models/stable_audio_transformer
title: StableAudioDiTModel
- local: api/models/prior_transformer
title: PriorTransformer
- local: api/models/controlnet
@@ -322,8 +320,6 @@
title: DiffEdit
- local: api/pipelines/dit
title: DiT
- local: api/pipelines/flux
title: Flux
- local: api/pipelines/hunyuandit
title: Hunyuan-DiT
- local: api/pipelines/i2vgenxl
@@ -370,8 +366,6 @@
title: Semantic Guidance
- local: api/pipelines/shap_e
title: Shap-E
- local: api/pipelines/stable_audio
title: Stable Audio
- local: api/pipelines/stable_cascade
title: Stable Cascade
- sections:
@@ -435,8 +429,6 @@
title: CMStochasticIterativeScheduler
- local: api/schedulers/consistency_decoder
title: ConsistencyDecoderScheduler
- local: api/schedulers/cosine_dpm
title: CosineDPMSolverMultistepScheduler
- local: api/schedulers/ddim_inverse
title: DDIMInverseScheduler
- local: api/schedulers/ddim
@@ -22,7 +22,6 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
## Supported pipelines
- [`CogVideoXPipeline`]
- [`StableDiffusionPipeline`]
- [`StableDiffusionImg2ImgPipeline`]
- [`StableDiffusionInpaintPipeline`]
@@ -50,7 +49,6 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
- [`UNet2DConditionModel`]
- [`StableCascadeUNet`]
- [`AutoencoderKL`]
- [`AutoencoderKLCogVideoX`]
- [`ControlNetModel`]
- [`SD3Transformer2DModel`]
@@ -1,38 +0,0 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AutoencoderOobleck
The Oobleck variational autoencoder (VAE) model with KL loss was introduced in [Stability-AI/stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools) and [Stable Audio Open](https://huggingface.co/papers/2407.14358) by Stability AI. The model is used in 🤗 Diffusers to encode audio waveforms into latents and to decode latent representations into audio waveforms.
The abstract from the paper is:
*Open generative models are vitally important for the community, allowing for fine-tunes and serving as baselines when presenting new models. However, most current text-to-audio models are private and not accessible for artists and researchers to build upon. Here we describe the architecture and training process of a new open-weights text-to-audio model trained with Creative Commons data. Our evaluation shows that the model's performance is competitive with the state-of-the-art across various metrics. Notably, the reported FDopenl3 results (measuring the realism of the generations) showcase its potential for high-quality stereo sound synthesis at 44.1kHz.*
## AutoencoderOobleck
[[autodoc]] AutoencoderOobleck
- decode
- encode
- all
## OobleckDecoderOutput
[[autodoc]] models.autoencoders.autoencoder_oobleck.OobleckDecoderOutput
## OobleckDecoderOutput
[[autodoc]] models.autoencoders.autoencoder_oobleck.OobleckDecoderOutput
## AutoencoderOobleckOutput
[[autodoc]] models.autoencoders.autoencoder_oobleck.AutoencoderOobleckOutput
@@ -1,69 +0,0 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLCogVideoX
The 3D variational autoencoder (VAE) model with KL loss using CogVideoX.
## Loading from the original format
By default, the [`AutoencoderKLCogVideoX`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded from the original format using [`FromOriginalModelMixin.from_single_file`] as follows:
```py
from diffusers import AutoencoderKLCogVideoX
url = "THUDM/CogVideoX-2b" # can also be a local file
model = AutoencoderKLCogVideoX.from_single_file(url)
```
## AutoencoderKLCogVideoX
[[autodoc]] AutoencoderKLCogVideoX
- decode
- encode
- all
## CogVideoXSafeConv3d
[[autodoc]] CogVideoXSafeConv3d
## CogVideoXCausalConv3d
[[autodoc]] CogVideoXCausalConv3d
## CogVideoXSpatialNorm3D
[[autodoc]] CogVideoXSpatialNorm3D
## CogVideoXResnetBlock3D
[[autodoc]] CogVideoXResnetBlock3D
## CogVideoXDownBlock3D
[[autodoc]] CogVideoXDownBlock3D
## CogVideoXMidBlock3D
[[autodoc]] CogVideoXMidBlock3D
## CogVideoXUpBlock3D
[[autodoc]] CogVideoXUpBlock3D
## CogVideoXEncoder3D
[[autodoc]] CogVideoXEncoder3D
## CogVideoXDecoder3D
[[autodoc]] CogVideoXDecoder3D
@@ -1,18 +0,0 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
## CogVideoXTransformer3DModel
A Diffusion Transformer model for 3D data from [CogVideoX](https://github.com/THUDM/CogVideoX).
## CogVideoXTransformer3DModel
[[autodoc]] CogVideoXTransformer3DModel
@@ -1,19 +0,0 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# FluxTransformer2DModel
A Transformer model for image-like data from [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).
## FluxTransformer2DModel
[[autodoc]] FluxTransformer2DModel
@@ -1,19 +0,0 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# StableAudioDiTModel
A Transformer model for audio waveforms from [Stable Audio Open](https://huggingface.co/papers/2407.14358).
## StableAudioDiTModel
[[autodoc]] StableAudioDiTModel
@@ -25,9 +25,6 @@ The abstract of the paper is the following:
| Pipeline | Tasks | Demo
|---|---|:---:|
| [AnimateDiffPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff.py) | *Text-to-Video Generation with AnimateDiff* |
| [AnimateDiffControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py) | *Controlled Video-to-Video Generation with AnimateDiff using ControlNet* |
| [AnimateDiffSparseControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py) | *Controlled Video-to-Video Generation with AnimateDiff using SparseCtrl* |
| [AnimateDiffSDXLPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py) | *Video-to-Video Generation with AnimateDiff* |
| [AnimateDiffVideoToVideoPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py) | *Video-to-Video Generation with AnimateDiff* |
## Available checkpoints
@@ -103,83 +100,6 @@ AnimateDiff tends to work better with finetuned Stable Diffusion models. If you
</Tip>
### AnimateDiffControlNetPipeline
AnimateDiff can also be used with ControlNets ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala. With a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide depth maps, the ControlNet model generates a video that'll preserve the spatial information from the depth maps. It is a more flexible and accurate way to control the video generation process.
```python
import torch
from diffusers import AnimateDiffControlNetPipeline, AutoencoderKL, ControlNetModel, MotionAdapter, LCMScheduler
from diffusers.utils import export_to_gif, load_video
# Additionally, you will need a preprocess videos before they can be used with the ControlNet
# HF maintains just the right package for it: `pip install controlnet_aux`
from controlnet_aux.processor import ZoeDetector
# Download controlnets from https://huggingface.co/lllyasviel/ControlNet-v1-1 to use .from_single_file
# Download Diffusers-format controlnets, such as https://huggingface.co/lllyasviel/sd-controlnet-depth, to use .from_pretrained()
controlnet = ControlNetModel.from_single_file("control_v11f1p_sd15_depth.pth", torch_dtype=torch.float16)
# We use AnimateLCM for this example but one can use the original motion adapters as well (for example, https://huggingface.co/guoyww/animatediff-motion-adapter-v1-5-3)
motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
pipe: AnimateDiffControlNetPipeline = AnimateDiffControlNetPipeline.from_pretrained(
"SG161222/Realistic_Vision_V5.1_noVAE",
motion_adapter=motion_adapter,
controlnet=controlnet,
vae=vae,
).to(device="cuda", dtype=torch.float16)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora")
pipe.set_adapters(["lcm-lora"], [0.8])
depth_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif")
conditioning_frames = []
with pipe.progress_bar(total=len(video)) as progress_bar:
for frame in video:
conditioning_frames.append(depth_detector(frame))
progress_bar.update()
prompt = "a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality"
negative_prompt = "bad quality, worst quality"
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=len(video),
num_inference_steps=10,
guidance_scale=2.0,
conditioning_frames=conditioning_frames,
generator=torch.Generator().manual_seed(42),
).frames[0]
export_to_gif(video, "animatediff_controlnet.gif", fps=8)
```
Here are some sample outputs:
<table align="center">
<tr>
<th align="center">Source Video</th>
<th align="center">Output Video</th>
</tr>
<tr>
<td align="center">
raccoon playing a guitar
<br />
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif" alt="racoon playing a guitar" />
</td>
<td align="center">
a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality
<br/>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-controlnet-output.gif" alt="a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality" />
</td>
</tr>
</table>
### AnimateDiffSparseControlNetPipeline
[SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion Models](https://arxiv.org/abs/2311.16933) for achieving controlled generation in text-to-video diffusion models by Yuwei Guo, Ceyuan Yang, Anyi Rao, Maneesh Agrawala, Dahua Lin, and Bo Dai.
@@ -842,12 +762,6 @@ pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapt
- all
- __call__
## AnimateDiffControlNetPipeline
[[autodoc]] AnimateDiffControlNetPipeline
- all
- __call__
## AnimateDiffSparseControlNetPipeline
[[autodoc]] AnimateDiffSparseControlNetPipeline
+1 -1
View File
@@ -18,7 +18,7 @@ It was developed by the Fal team and more details about it can be found in [this
<Tip>
AuraFlow can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details.
AuraFlow can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details.
</Tip>
-79
View File
@@ -1,79 +0,0 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
## TODO: The paper is still being written.
-->
# CogVideoX
[TODO]() from Tsinghua University & ZhipuAI.
The abstract from the paper is:
The paper is still being written.
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
### Inference
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
First, load the pipeline:
```python
import torch
from diffusers import LattePipeline
pipeline = LattePipeline.from_pretrained(
"THUDM/CogVideoX-2b", torch_dtype=torch.float16
).to("cuda")
```
Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`:
```python
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
```
Finally, compile the components and run inference:
```python
pipeline.transformer = torch.compile(pipeline.transformer)
pipeline.vae.decode = torch.compile(pipeline.vae.decode)
# CogVideoX works very well with long and well-described prompts
prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
```
The [benchmark](TODO: link) results on an 80GB A100 machine are:
```
Without torch.compile(): Average inference time: TODO seconds.
With torch.compile(): Average inference time: TODO seconds.
```
## CogVideoXPipeline
[[autodoc]] CogVideoXPipeline
- all
- __call__
## CogVideoXPipelineOutput
[[autodoc]] pipelines.pipline_cogvideo.pipeline_output.CogVideoXPipelineOutput
-84
View File
@@ -1,84 +0,0 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Flux
Flux is a series of text-to-image generation models based on diffusion transformers. To know more about Flux, check out the original [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/) by the creators of Flux, Black Forest Labs.
Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux).
<Tip>
Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).
</Tip>
Flux comes in two variants:
* Timestep-distilled (`black-forest-labs/FLUX.1-schnell`)
* Guidance-distilled (`black-forest-labs/FLUX.1-dev`)
Both checkpoints have slightly difference usage which we detail below.
### Timestep-distilled
* `max_sequence_length` cannot be more than 256.
* `guidance_scale` needs to be 0.
* As this is a timestep-distilled model, it benefits from fewer sampling steps.
```python
import torch
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
out = pipe(
prompt=prompt,
guidance_scale=0.,
height=768,
width=1360,
num_inference_steps=4,
max_sequence_length=256,
).images[0]
out.save("image.png")
```
### Guidance-distilled
* The guidance-distilled variant takes about 50 sampling steps for good-quality generation.
* It doesn't have any limitations around the `max_sequence_length`.
```python
import torch
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
prompt = "a tiny astronaut hatching from an egg on the moon"
out = pipe(
prompt=prompt,
guidance_scale=3.5,
height=768,
width=1360,
num_inference_steps=50,
).images[0]
out.save("image.png")
```
## FluxPipeline
[[autodoc]] FluxPipeline
- all
- __call__
+2 -2
View File
@@ -59,7 +59,7 @@ First, load the pipeline:
```python
from diffusers import LuminaText2ImgPipeline
import torch
import torch
pipeline = LuminaText2ImgPipeline.from_pretrained(
"Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
@@ -87,4 +87,4 @@ image = pipeline(prompt="Upper body of a young woman in a Victorian-era outfit w
[[autodoc]] LuminaText2ImgPipeline
- all
- __call__
-1
View File
@@ -71,7 +71,6 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [Semantic Guidance](semantic_stable_diffusion) | text2image |
| [Shap-E](shap_e) | text-to-3D, image-to-3D |
| [Spectrogram Diffusion](spectrogram_diffusion) | |
| [Stable Audio](stable_audio) | text2audio |
| [Stable Diffusion](stable_diffusion/overview) | text2image, image2image, depth2image, inpainting, image variation, latent upscaler, super-resolution |
| [Stable Diffusion Model Editing](model_editing) | model editing |
| [Stable Diffusion XL](stable_diffusion/stable_diffusion_xl) | text2image, image2image, inpainting |
-29
View File
@@ -20,29 +20,6 @@ The abstract from the paper is:
*Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.*
PAG can be used by specifying the `pag_applied_layers` as a parameter when instantiating a PAG pipeline. It can be a single string or a list of strings. Each string can be a unique layer identifier or a regular expression to identify one or more layers.
- Full identifier as a normal string: `down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor`
- Full identifier as a RegEx: `down_blocks.2.(attentions|motion_modules).0.transformer_blocks.0.attn1.processor`
- Partial identifier as a RegEx: `down_blocks.2`, or `attn1`
- List of identifiers (can be combo of strings and ReGex): `["blocks.1", "blocks.(14|20)", r"down_blocks\.(2,3)"]`
<Tip warning={true}>
Since RegEx is supported as a way for matching layer identifiers, it is crucial to use it correctly otherwise there might be unexpected behaviour. The recommended way to use PAG is by specifying layers as `blocks.{layer_index}` and `blocks.({layer_index_1|layer_index_2|...})`. Using it in any other way, while doable, may bypass our basic validation checks and give you unexpected results.
</Tip>
## AnimateDiffPAGPipeline
[[autodoc]] AnimateDiffPAGPipeline
- all
- __call__
## HunyuanDiTPAGPipeline
[[autodoc]] HunyuanDiTPAGPipeline
- all
- __call__
## StableDiffusionPAGPipeline
[[autodoc]] StableDiffusionPAGPipeline
- all
@@ -72,9 +49,3 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
[[autodoc]] StableDiffusionXLControlNetPAGPipeline
- all
- __call__
## PixArtSigmaPAGPipeline
[[autodoc]] PixArtSigmaPAGPipeline
- all
- __call__
@@ -1,42 +0,0 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Stable Audio
Stable Audio was proposed in [Stable Audio Open](https://arxiv.org/abs/2407.14358) by Zach Evans et al. . it takes a text prompt as input and predicts the corresponding sound or music sample.
Stable Audio Open generates variable-length (up to 47s) stereo audio at 44.1kHz from text prompts. It comprises three components: an autoencoder that compresses waveforms into a manageable sequence length, a T5-based text embedding for text conditioning, and a transformer-based diffusion (DiT) model that operates in the latent space of the autoencoder.
Stable Audio is trained on a corpus of around 48k audio recordings, where around 47k are from Freesound and the rest are from the Free Music Archive (FMA). All audio files are licensed under CC0, CC BY, or CC Sampling+. This data is used to train the autoencoder and the DiT.
The abstract of the paper is the following:
*Open generative models are vitally important for the community, allowing for fine-tunes and serving as baselines when presenting new models. However, most current text-to-audio models are private and not accessible for artists and researchers to build upon. Here we describe the architecture and training process of a new open-weights text-to-audio model trained with Creative Commons data. Our evaluation shows that the model's performance is competitive with the state-of-the-art across various metrics. Notably, the reported FDopenl3 results (measuring the realism of the generations) showcase its potential for high-quality stereo sound synthesis at 44.1kHz.*
This pipeline was contributed by [Yoach Lacombe](https://huggingface.co/ylacombe). The original codebase can be found at [Stability-AI/stable-audio-tool](https://github.com/Stability-AI/stable-audio-tool).
## Tips
When constructing a prompt, keep in mind:
* Descriptive prompt inputs work best; use adjectives to describe the sound (for example, "high quality" or "clear") and make the prompt context specific where possible (e.g. "melodic techno with a fast beat and synths" works better than "techno").
* Using a *negative prompt* can significantly improve the quality of the generated audio. Try using a negative prompt of "low quality, average quality".
During inference:
* The _quality_ of the generated audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference.
* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1 to enable. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.
## StableAudioPipeline
[[autodoc]] StableAudioPipeline
- all
- __call__
@@ -1,24 +0,0 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# CosineDPMSolverMultistepScheduler
The [`CosineDPMSolverMultistepScheduler`] is a variant of [`DPMSolverMultistepScheduler`] with cosine schedule, proposed by Nichol and Dhariwal (2021).
It is being used in the [Stable Audio Open](https://arxiv.org/abs/2407.14358) paper and the [Stability-AI/stable-audio-tool](https://github.com/Stability-AI/stable-audio-tool) codebase.
This scheduler was contributed by [Yoach Lacombe](https://huggingface.co/ylacombe).
## CosineDPMSolverMultistepScheduler
[[autodoc]] CosineDPMSolverMultistepScheduler
## SchedulerOutput
[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
+78
View File
@@ -0,0 +1,78 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Community Projects
Welcome to Community Projects. This space is dedicated to showcasing the incredible work and innovative applications created by our vibrant community using the `diffusers` library.
This section aims to:
- Highlight diverse and inspiring projects built with `diffusers`
- Foster knowledge sharing within our community
- Provide real-world examples of how `diffusers` can be leveraged
Happy exploring, and thank you for being part of the Diffusers community!
<table>
<tr>
<th>Project Name</th>
<th>Description</th>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/carson-katri/dream-textures"> dream-textures </a></td>
<td>Stable Diffusion built-in to Blender</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/megvii-research/HiDiffusion"> HiDiffusion </a></td>
<td>Increases the resolution and speed of your diffusion model by only adding a single line of code</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/lllyasviel/IC-Light"> IC-Light </a></td>
<td>IC-Light is a project to manipulate the illumination of images</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/InstantID/InstantID"> InstantID </a></td>
<td>InstantID : Zero-shot Identity-Preserving Generation in Seconds</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/Sanster/IOPaint"> IOPaint </a></td>
<td>Image inpainting tool powered by SOTA AI Model. Remove any unwanted object, defect, people from your pictures or erase and replace(powered by stable diffusion) any thing on your pictures.</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/bmaltais/kohya_ss"> Kohya </a></td>
<td>Gradio GUI for Kohya's Stable Diffusion trainers</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/magic-research/magic-animate"> MagicAnimate </a></td>
<td>MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/levihsu/OOTDiffusion"> OOTDiffusion </a></td>
<td>Outfitting Fusion based Latent Diffusion for Controllable Virtual Try-on</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/vladmandic/automatic"> SD.Next </a></td>
<td>SD.Next: Advanced Implementation of Stable Diffusion and other Diffusion-based generative image models</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/ashawkey/stable-dreamfusion"> stable-dreamfusion </a></td>
<td>Text-to-3D & Image-to-3D & Mesh Exportation with NeRF + Diffusion</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/HVision-NKU/StoryDiffusion"> StoryDiffusion </a></td>
<td>StoryDiffusion can create a magic story by generating consistent images and videos.</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/cumulo-autumn/StreamDiffusion"> StreamDiffusion </a></td>
<td>A Pipeline-Level Solution for Real-Time Interactive Generation</td>
</tr>
</table>
+2 -2
View File
@@ -35,7 +35,7 @@ pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu
```
> [!TIP]
> The results reported below are from a 80GB 400W A100 with its clock rate set to the maximum.
> The results reported below are from a 80GB 400W A100 with its clock rate set to the maximum.
> If you're interested in the full benchmarking code, take a look at [huggingface/diffusion-fast](https://github.com/huggingface/diffusion-fast).
@@ -168,7 +168,7 @@ Using SDPA attention and compiling both the UNet and VAE cuts the latency from 3
</div>
> [!TIP]
> From PyTorch 2.3.1, you can control the caching behavior of `torch.compile()`. This is particularly beneficial for compilation modes like `"max-autotune"` which performs a grid-search over several compilation flags to find the optimal configuration. Learn more in the [Compile Time Caching in torch.compile](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) tutorial.
> From PyTorch 2.3.1, you can control the caching behavior of `torch.compile()`. This is particularly beneficial for compilation modes like `"max-autotune"` which performs a grid-search over several compilation flags to find the optimal configuration. Learn more in the [Compile Time Caching in torch.compile](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) tutorial.
### Prevent graph breaks
@@ -18,13 +18,13 @@ A modern diffusion model, like [Stable Diffusion XL (SDXL)](../using-diffusers/s
* Two text encoders
* A UNet for denoising
Usually, the text encoders and the denoiser are much larger compared to the VAE.
Usually, the text encoders and the denoiser are much larger compared to the VAE.
As models get bigger and better, its possible your model is so big that even a single copy wont fit in memory. But that doesnt mean it cant be loaded. If you have more than one GPU, there is more memory available to store your model. In this case, its better to split your model checkpoint into several smaller *checkpoint shards*.
When a text encoder checkpoint has multiple shards, like [T5-xxl for SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers/tree/main/text_encoder_3), it is automatically handled by the [Transformers](https://huggingface.co/docs/transformers/index) library as it is a required dependency of Diffusers when using the [`StableDiffusion3Pipeline`]. More specifically, Transformers will automatically handle the loading of multiple shards within the requested model class and get it ready so that inference can be performed.
The denoiser checkpoint can also have multiple shards and supports inference thanks to the [Accelerate](https://huggingface.co/docs/accelerate/index) library.
The denoiser checkpoint can also have multiple shards and supports inference thanks to the [Accelerate](https://huggingface.co/docs/accelerate/index) library.
> [!TIP]
> Refer to the [Handling big models for inference](https://huggingface.co/docs/accelerate/main/en/concept_guides/big_model_inference) guide for general guidance when working with big models that are hard to fit into memory.
@@ -43,7 +43,7 @@ unet.save_pretrained("sdxl-unet-sharded", max_shard_size="5GB")
The size of the fp32 variant of the SDXL UNet checkpoint is ~10.4GB. Set the `max_shard_size` parameter to 5GB to create 3 shards. After saving, you can load them in [`StableDiffusionXLPipeline`]:
```python
from diffusers import UNet2DConditionModel, StableDiffusionXLPipeline
from diffusers import UNet2DConditionModel, StableDiffusionXLPipeline
import torch
unet = UNet2DConditionModel.from_pretrained(
@@ -57,14 +57,14 @@ image = pipeline("a cute dog running on the grass", num_inference_steps=30).imag
image.save("dog.png")
```
If placing all the model-level components on the GPU at once is not feasible, use [`~DiffusionPipeline.enable_model_cpu_offload`] to help you:
If placing all the model-level components on the GPU at once is not feasible, use [`~DiffusionPipeline.enable_model_cpu_offload`] to help you:
```diff
- pipeline.to("cuda")
+ pipeline.enable_model_cpu_offload()
```
In general, we recommend sharding when a checkpoint is more than 5GB (in fp32).
In general, we recommend sharding when a checkpoint is more than 5GB (in fp32).
## Device placement
+1 -1
View File
@@ -256,7 +256,7 @@ make_image_grid([init_image, mask_image, output], rows=1, cols=3)
## Guess mode
[Guess mode](https://github.com/lllyasviel/ControlNet/discussions/188) does not require supplying a prompt to a ControlNet at all! This forces the ControlNet encoder to do its best to "guess" the contents of the input control map (depth map, pose estimation, canny edge, etc.).
[Guess mode](https://github.com/lllyasviel/ControlNet/discussions/188) does not require supplying a prompt to a ControlNet at all! This forces the ControlNet encoder to do it's best to "guess" the contents of the input control map (depth map, pose estimation, canny edge, etc.).
Guess mode adjusts the scale of the output residuals from a ControlNet by a fixed ratio depending on the block depth. The shallowest `DownBlock` corresponds to 0.1, and as the blocks get deeper, the scale increases exponentially such that the scale of the `MidBlock` output becomes 1.0.
+9 -9
View File
@@ -22,7 +22,7 @@ This guide will show you how to use PAG for various tasks and use cases.
You can apply PAG to the [`StableDiffusionXLPipeline`] for tasks such as text-to-image, image-to-image, and inpainting. To enable PAG for a specific task, load the pipeline using the [AutoPipeline](../api/pipelines/auto_pipeline) API with the `enable_pag=True` flag and the `pag_applied_layers` argument.
> [!TIP]
> 🤗 Diffusers currently only supports using PAG with selected SDXL pipelines and [`PixArtSigmaPAGPipeline`]. But feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to add PAG support to a new pipeline!
> 🤗 Diffusers currently only supports using PAG with selected SDXL pipelines, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to add PAG support to a new pipeline!
<hfoptions id="tasks">
<hfoption id="Text-to-image">
@@ -130,10 +130,10 @@ prompt = "a dog catching a frisbee in the jungle"
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipeline(
prompt,
image=init_image,
strength=0.8,
guidance_scale=guidance_scale,
prompt,
image=init_image,
strength=0.8,
guidance_scale=guidance_scale,
pag_scale=pag_scale,
generator=generator).images[0]
```
@@ -161,14 +161,14 @@ pipeline_inpaint = AutoPipelineForInpaiting.from_pretrained("stabilityai/stable-
pipeline = AutoPipelineForInpaiting.from_pipe(pipeline_inpaint, enable_pag=True)
```
This still works when your pipeline has a different task:
This still works when your pipeline has a different task:
```py
pipeline_t2i = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipeline = AutoPipelineForInpaiting.from_pipe(pipeline_t2i, enable_pag=True)
```
Let's generate an image!
Let's generate an image!
```py
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
@@ -258,7 +258,7 @@ for pag_scale in [0.0, 3.0]:
</div>
</div>
## PAG with IP-Adapter
## PAG with IP-Adapter
[IP-Adapter](https://hf.co/papers/2308.06721) is a popular model that can be plugged into diffusion models to enable image prompting without any changes to the underlying model. You can enable PAG on a pipeline with IP-Adapter loaded.
@@ -317,7 +317,7 @@ PAG reduces artifacts and improves the overall compposition.
</div>
## Configure parameters
## Configure parameters
### pag_applied_layers
+1 -1
View File
@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# 철학 [[philosophy]]
# 철학 [[philosophy]]
🧨 Diffusers는 다양한 모달리티에서 **최신의** 사전 훈련된 diffusion 모델을 제공합니다.
그 목적은 추론과 훈련을 위한 **모듈식 툴박스**로 사용되는 것입니다.
+2 -2
View File
@@ -52,7 +52,7 @@ pipeline = pipeline.to("cuda")
Text-to-image의 경우 텍스트 프롬프트를 전달합니다. 기본적으로 SDXL Turbo는 512x512 이미지를 생성하며, 이 해상도에서 최상의 결과를 제공합니다. `height``width` 매개 변수를 768x768 또는 1024x1024로 설정할 수 있지만 이 경우 품질 저하를 예상할 수 있습니다.
모델이 `guidance_scale` 없이 학습되었으므로 이를 0.0으로 설정해 비활성화해야 합니다. 단일 추론 스텝만으로도 고품질 이미지를 생성할 수 있습니다.
모델이 `guidance_scale` 없이 학습되었으므로 이를 0.0으로 설정해 비활성화해야 합니다. 단일 추론 스텝만으로도 고품질 이미지를 생성할 수 있습니다.
스텝 수를 2, 3 또는 4로 늘리면 이미지 품질이 향상됩니다.
```py
@@ -74,7 +74,7 @@ image
## Image-to-image
Image-to-image 생성의 경우 `num_inference_steps * strength`가 1보다 크거나 같은지 확인하세요.
Image-to-image 생성의 경우 `num_inference_steps * strength`가 1보다 크거나 같은지 확인하세요.
Image-to-image 파이프라인은 아래 예제에서 `0.5 * 2.0 = 1` 스텝과 같이 `int(num_inference_steps * strength)` 스텝으로 실행됩니다.
```py
+1 -1
View File
@@ -21,7 +21,7 @@ specific language governing permissions and limitations under the License.
시작하기 전에 다음 라이브러리가 설치되어 있는지 확인하세요:
```py
!pip install -q -U diffusers transformers accelerate
!pip install -q -U diffusers transformers accelerate
```
이 모델에는 [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid)와 [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt) 두 가지 종류가 있습니다. SVD 체크포인트는 14개의 프레임을 생성하도록 학습되었고, SVD-XT 체크포인트는 25개의 프레임을 생성하도록 파인튜닝되었습니다.
+10 -9
View File
@@ -1487,16 +1487,17 @@ NOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes.
```python
import torch
from diffusers import DDIMScheduler
from diffusers.pipelines import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
# Use the DDIMScheduler scheduler here instead
scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1",
subfolder="scheduler")
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
custom_pipeline="stable_diffusion_tensorrt_txt2img",
variant='fp16',
torch_dtype=torch.float16,
scheduler=scheduler,)
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
custom_pipeline="stable_diffusion_tensorrt_txt2img",
variant='fp16',
torch_dtype=torch.float16,
scheduler=scheduler,)
# re-use cached folder to save ONNX models and TensorRT Engines
pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", variant='fp16',)
@@ -2230,12 +2231,12 @@ from io import BytesIO
from PIL import Image
import torch
from diffusers import PNDMScheduler
from diffusers.pipelines import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
# Use the PNDMScheduler scheduler here instead
scheduler = PNDMScheduler.from_pretrained("stabilityai/stable-diffusion-2-inpainting", subfolder="scheduler")
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting",
pipe = StableDiffusionInpaintPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting",
custom_pipeline="stable_diffusion_tensorrt_inpaint",
variant='fp16',
torch_dtype=torch.float16,
+1 -1
View File
@@ -2436,7 +2436,7 @@ class FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):
)
if guess_mode and self.do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
@@ -1002,7 +1002,7 @@ class StableDiffusionXLInstantIDImg2ImgPipeline(StableDiffusionXLControlNetImg2I
)
if guess_mode and self.do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
@@ -991,7 +991,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
)
if guess_mode and self.do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+2 -2
View File
@@ -864,7 +864,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
)
if guess_mode and do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
@@ -1038,7 +1038,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
)
if guess_mode and do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [
@@ -752,7 +752,7 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeli
)
if guess_mode and do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
@@ -60,7 +60,7 @@ from diffusers.utils import logging
"""
Installation instructions
python3 -m pip install --upgrade transformers diffusers>=0.16.0
python3 -m pip install --upgrade tensorrt~=10.2.0
python3 -m pip install --upgrade tensorrt-cu12==10.2.0
python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
python3 -m pip install onnxruntime
"""
@@ -659,7 +659,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
r"""
Pipeline for image-to-image generation using TensorRT accelerated Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
This model inherits from [`StableDiffusionImg2ImgPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
@@ -18,7 +18,8 @@
import gc
import os
from collections import OrderedDict
from typing import List, Optional, Tuple, Union
from copy import copy
from typing import List, Optional, Union
import numpy as np
import onnx
@@ -26,11 +27,9 @@ import onnx_graphsurgeon as gs
import PIL.Image
import tensorrt as trt
import torch
from cuda import cudart
from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from onnx import shape_inference
from packaging import version
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.onnx.loader import fold_constants
@@ -42,29 +41,24 @@ from polygraphy.backend.trt import (
network_from_onnx_path,
save_engine,
)
from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict, deprecate
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
StableDiffusionInpaintPipeline,
StableDiffusionPipelineOutput,
StableDiffusionSafetyChecker,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import (
prepare_mask_and_masked_image,
retrieve_latents,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
"""
Installation instructions
python3 -m pip install --upgrade transformers diffusers>=0.16.0
python3 -m pip install --upgrade tensorrt~=10.2.0
python3 -m pip install --upgrade tensorrt>=8.6.1
python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
python3 -m pip install onnxruntime
"""
@@ -94,6 +88,10 @@ else:
torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
def device_view(t):
return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype])
def preprocess_image(image):
"""
image: torch.Tensor
@@ -127,8 +125,10 @@ class Engine:
onnx_path,
fp16,
input_profile=None,
enable_preview=False,
enable_all_tactics=False,
timing_cache=None,
workspace_size=0,
):
logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
p = Profile()
@@ -137,13 +137,20 @@ class Engine:
assert len(dims) == 3
p.add(name, min=dims[0], opt=dims[1], max=dims[2])
extra_build_args = {}
config_kwargs = {}
config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
if enable_preview:
# Faster dynamic shapes made optional since it increases engine build time.
config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)
if workspace_size > 0:
config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
if not enable_all_tactics:
extra_build_args["tactic_sources"] = []
config_kwargs["tactic_sources"] = []
engine = engine_from_network(
network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args),
config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs),
save_timing_cache=timing_cache,
)
save_engine(engine, path=self.engine_path)
@@ -156,24 +163,28 @@ class Engine:
self.context = self.engine.create_execution_context()
def allocate_buffers(self, shape_dict=None, device="cuda"):
for binding in range(self.engine.num_io_tensors):
name = self.engine.get_tensor_name(binding)
if shape_dict and name in shape_dict:
shape = shape_dict[name]
for idx in range(trt_util.get_bindings_per_profile(self.engine)):
binding = self.engine[idx]
if shape_dict and binding in shape_dict:
shape = shape_dict[binding]
else:
shape = self.engine.get_tensor_shape(name)
dtype = trt.nptype(self.engine.get_tensor_dtype(name))
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
self.context.set_input_shape(name, shape)
shape = self.engine.get_binding_shape(binding)
dtype = trt.nptype(self.engine.get_binding_dtype(binding))
if self.engine.binding_is_input(binding):
self.context.set_binding_shape(idx, shape)
tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
self.tensors[name] = tensor
self.tensors[binding] = tensor
self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)
def infer(self, feed_dict, stream):
start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
# shallow copy of ordered dict
device_buffers = copy(self.buffers)
for name, buf in feed_dict.items():
self.tensors[name].copy_(buf)
for name, tensor in self.tensors.items():
self.context.set_tensor_address(name, tensor.data_ptr())
noerror = self.context.execute_async_v3(stream)
assert isinstance(buf, cuda.DeviceView)
device_buffers[name] = buf
bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
if not noerror:
raise ValueError("ERROR: inference failed.")
@@ -314,8 +325,10 @@ def build_engines(
force_engine_rebuild=False,
static_batch=False,
static_shape=True,
enable_preview=False,
enable_all_tactics=False,
timing_cache=None,
max_workspace_size=0,
):
built_engines = {}
if not os.path.isdir(onnx_dir):
@@ -380,7 +393,9 @@ def build_engines(
static_batch=static_batch,
static_shape=static_shape,
),
enable_preview=enable_preview,
timing_cache=timing_cache,
workspace_size=max_workspace_size,
)
built_engines[model_name] = engine
@@ -659,11 +674,11 @@ def make_VAEEncoder(model, device, max_batch_size, embedding_dim, inpaint=False)
return VAEEncoder(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
r"""
Pipeline for inpainting using TensorRT accelerated Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
This model inherits from [`StableDiffusionInpaintPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
@@ -687,8 +702,6 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
def __init__(
self,
vae: AutoencoderKL,
@@ -709,86 +722,24 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
onnx_dir: str = "onnx",
# TensorRT engine build parameters
engine_dir: str = "engine",
build_preview_features: bool = True,
force_engine_rebuild: bool = False,
timing_cache: str = "timing_cache",
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
super().__init__(
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
)
self.vae.forward = self.vae.decode
self.stages = stages
self.image_height, self.image_width = image_height, image_width
self.inpaint = True
@@ -799,6 +750,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
self.timing_cache = timing_cache
self.build_static_batch = False
self.build_dynamic_shape = False
self.build_preview_features = build_preview_features
self.max_batch_size = max_batch_size
# TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
@@ -809,11 +761,6 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
self.models = {} # loaded in __loadModels()
self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def __loadModels(self):
# Load pipeline models
self.embedding_dim = self.text_encoder.config.hidden_size
@@ -832,112 +779,6 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
if "vae_encoder" in self.stages:
self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
image_latents = self.vae.config.scaling_factor * image_latents
return image_latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
image=None,
timestep=None,
is_strength_max=True,
return_noise=False,
return_image_latents=False,
):
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if (image is None or timestep is None) and not is_strength_max:
raise ValueError(
"Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
"However, either the image or the noise timestep has not been provided."
)
if return_image_latents or (latents is None and not is_strength_max):
image = image.to(device=device, dtype=dtype)
if image.shape[1] == 4:
image_latents = image
else:
image_latents = self._encode_vae_image(image=image, generator=generator)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# if strength is 1. then initialise the latents to noise, else initial to image + noise
latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
# if pure noise then scale the initial latents by the Scheduler's init sigma
latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
else:
noise = latents.to(device)
latents = noise * self.scheduler.init_noise_sigma
outputs = (latents,)
if return_noise:
outputs += (noise,)
if return_image_latents:
outputs += (image_latents,)
return outputs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(
self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
r"""
Runs the safety checker on the given image.
Args:
image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.
device (torch.device): The device to run the safety checker on.
dtype (torch.dtype): The data type of the input image.
Returns:
(image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and
a boolean indicating whether the image has a NSFW (Not Safe for Work) concept.
"""
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
@classmethod
@validate_hf_hub_args
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
@@ -985,6 +826,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
force_engine_rebuild=self.force_engine_rebuild,
static_batch=self.build_static_batch,
static_shape=not self.build_dynamic_shape,
enable_preview=self.build_preview_features,
timing_cache=self.timing_cache,
)
@@ -1008,7 +850,9 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
return tuple(init_images)
def __encode_image(self, init_image):
init_latents = runEngine(self.engine["vae_encoder"], {"images": init_image}, self.stream)["latent"]
init_latents = runEngine(self.engine["vae_encoder"], {"images": device_view(init_image)}, self.stream)[
"latent"
]
init_latents = 0.18215 * init_latents
return init_latents
@@ -1037,8 +881,9 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
.to(self.torch_device)
)
text_input_ids_inp = device_view(text_input_ids)
# NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids}, self.stream)[
text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[
"text_embeddings"
].clone()
@@ -1054,7 +899,8 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
.input_ids.type(torch.int32)
.to(self.torch_device)
)
uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids}, self.stream)[
uncond_input_ids_inp = device_view(uncond_input_ids)
uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[
"text_embeddings"
]
@@ -1078,15 +924,18 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
# Predict the noise residual
timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep
sample_inp = device_view(latent_model_input)
timestep_inp = device_view(timestep_float)
embeddings_inp = device_view(text_embeddings)
noise_pred = runEngine(
self.engine["unet"],
{"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings},
{"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp},
self.stream,
)["latent"]
# Perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
@@ -1094,12 +943,12 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
return latents
def __decode_latent(self, latents):
images = runEngine(self.engine["vae"], {"latent": latents}, self.stream)["images"]
images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"]
images = (images / 2 + 0.5).clamp(0, 1)
return images.cpu().permute(0, 2, 3, 1).float().numpy()
def __loadResources(self, image_height, image_width, batch_size):
self.stream = cudart.cudaStreamCreate()[1]
self.stream = cuda.Stream()
# Allocate buffers for TensorRT engine bindings
for model_name, obj in self.models.items():
@@ -1263,6 +1112,5 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
# VAE decode latent
images = self.__decode_latent(latents)
images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype)
images = self.numpy_to_pil(images)
return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=None)
@@ -18,19 +18,17 @@
import gc
import os
from collections import OrderedDict
from typing import List, Optional, Tuple, Union
from copy import copy
from typing import List, Optional, Union
import numpy as np
import onnx
import onnx_graphsurgeon as gs
import PIL.Image
import tensorrt as trt
import torch
from cuda import cudart
from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from onnx import shape_inference
from packaging import version
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.onnx.loader import fold_constants
@@ -42,25 +40,23 @@ from polygraphy.backend.trt import (
network_from_onnx_path,
save_engine,
)
from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict, deprecate
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
StableDiffusionPipeline,
StableDiffusionPipelineOutput,
StableDiffusionSafetyChecker,
)
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
"""
Installation instructions
python3 -m pip install --upgrade transformers diffusers>=0.16.0
python3 -m pip install --upgrade tensorrt~=10.2.0
python3 -m pip install --upgrade tensorrt>=8.6.1
python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
python3 -m pip install onnxruntime
"""
@@ -90,6 +86,10 @@ else:
torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
def device_view(t):
return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype])
class Engine:
def __init__(self, engine_path):
self.engine_path = engine_path
@@ -110,8 +110,10 @@ class Engine:
onnx_path,
fp16,
input_profile=None,
enable_preview=False,
enable_all_tactics=False,
timing_cache=None,
workspace_size=0,
):
logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
p = Profile()
@@ -120,13 +122,20 @@ class Engine:
assert len(dims) == 3
p.add(name, min=dims[0], opt=dims[1], max=dims[2])
extra_build_args = {}
config_kwargs = {}
config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
if enable_preview:
# Faster dynamic shapes made optional since it increases engine build time.
config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)
if workspace_size > 0:
config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
if not enable_all_tactics:
extra_build_args["tactic_sources"] = []
config_kwargs["tactic_sources"] = []
engine = engine_from_network(
network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args),
config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs),
save_timing_cache=timing_cache,
)
save_engine(engine, path=self.engine_path)
@@ -139,24 +148,28 @@ class Engine:
self.context = self.engine.create_execution_context()
def allocate_buffers(self, shape_dict=None, device="cuda"):
for binding in range(self.engine.num_io_tensors):
name = self.engine.get_tensor_name(binding)
if shape_dict and name in shape_dict:
shape = shape_dict[name]
for idx in range(trt_util.get_bindings_per_profile(self.engine)):
binding = self.engine[idx]
if shape_dict and binding in shape_dict:
shape = shape_dict[binding]
else:
shape = self.engine.get_tensor_shape(name)
dtype = trt.nptype(self.engine.get_tensor_dtype(name))
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
self.context.set_input_shape(name, shape)
shape = self.engine.get_binding_shape(binding)
dtype = trt.nptype(self.engine.get_binding_dtype(binding))
if self.engine.binding_is_input(binding):
self.context.set_binding_shape(idx, shape)
tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
self.tensors[name] = tensor
self.tensors[binding] = tensor
self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)
def infer(self, feed_dict, stream):
start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
# shallow copy of ordered dict
device_buffers = copy(self.buffers)
for name, buf in feed_dict.items():
self.tensors[name].copy_(buf)
for name, tensor in self.tensors.items():
self.context.set_tensor_address(name, tensor.data_ptr())
noerror = self.context.execute_async_v3(stream)
assert isinstance(buf, cuda.DeviceView)
device_buffers[name] = buf
bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
if not noerror:
raise ValueError("ERROR: inference failed.")
@@ -297,8 +310,10 @@ def build_engines(
force_engine_rebuild=False,
static_batch=False,
static_shape=True,
enable_preview=False,
enable_all_tactics=False,
timing_cache=None,
max_workspace_size=0,
):
built_engines = {}
if not os.path.isdir(onnx_dir):
@@ -363,7 +378,9 @@ def build_engines(
static_batch=static_batch,
static_shape=static_shape,
),
enable_preview=enable_preview,
timing_cache=timing_cache,
workspace_size=max_workspace_size,
)
built_engines[model_name] = engine
@@ -571,11 +588,11 @@ def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False):
return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
class TensorRTStableDiffusionPipeline(DiffusionPipeline):
class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
r"""
Pipeline for text-to-image generation using TensorRT accelerated Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
This model inherits from [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
@@ -599,8 +616,6 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae: AutoencoderKL,
@@ -617,90 +632,28 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
image_width: int = 768,
max_batch_size: int = 16,
# ONNX export parameters
onnx_opset: int = 18,
onnx_opset: int = 17,
onnx_dir: str = "onnx",
# TensorRT engine build parameters
engine_dir: str = "engine",
build_preview_features: bool = True,
force_engine_rebuild: bool = False,
timing_cache: str = "timing_cache",
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
super().__init__(
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
)
self.vae.forward = self.vae.decode
self.stages = stages
self.image_height, self.image_width = image_height, image_width
self.inpaint = False
@@ -711,6 +664,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
self.timing_cache = timing_cache
self.build_static_batch = False
self.build_dynamic_shape = False
self.build_preview_features = build_preview_features
self.max_batch_size = max_batch_size
# TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
@@ -721,11 +675,6 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
self.models = {} # loaded in __loadModels()
self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def __loadModels(self):
# Load pipeline models
self.embedding_dim = self.text_encoder.config.hidden_size
@@ -742,75 +691,6 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
if "vae" in self.stages:
self.models["vae"] = make_VAE(self.vae, **models_args)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
generator: Union[torch.Generator, List[torch.Generator]],
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
Prepare the latent vectors for diffusion.
Args:
batch_size (int): The number of samples in the batch.
num_channels_latents (int): The number of channels in the latent vectors.
height (int): The height of the latent vectors.
width (int): The width of the latent vectors.
dtype (torch.dtype): The data type of the latent vectors.
device (torch.device): The device to place the latent vectors on.
generator (Union[torch.Generator, List[torch.Generator]]): The generator(s) to use for random number generation.
latents (Optional[torch.Tensor]): The pre-existing latent vectors. If None, new latent vectors will be generated.
Returns:
torch.Tensor: The prepared latent vectors.
"""
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(
self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype
) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:
r"""
Runs the safety checker on the given image.
Args:
image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.
device (torch.device): The device to run the safety checker on.
dtype (torch.dtype): The data type of the input image.
Returns:
(image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and
a boolean indicating whether the image has a NSFW (Not Safe for Work) concept.
"""
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
@classmethod
@validate_hf_hub_args
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
@@ -858,6 +738,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
force_engine_rebuild=self.force_engine_rebuild,
static_batch=self.build_static_batch,
static_shape=not self.build_dynamic_shape,
enable_preview=self.build_preview_features,
timing_cache=self.timing_cache,
)
@@ -888,8 +769,9 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
.to(self.torch_device)
)
text_input_ids_inp = device_view(text_input_ids)
# NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids}, self.stream)[
text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[
"text_embeddings"
].clone()
@@ -905,7 +787,8 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
.input_ids.type(torch.int32)
.to(self.torch_device)
)
uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids}, self.stream)[
uncond_input_ids_inp = device_view(uncond_input_ids)
uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[
"text_embeddings"
]
@@ -929,15 +812,18 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
# Predict the noise residual
timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep
sample_inp = device_view(latent_model_input)
timestep_inp = device_view(timestep_float)
embeddings_inp = device_view(text_embeddings)
noise_pred = runEngine(
self.engine["unet"],
{"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings},
{"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp},
self.stream,
)["latent"]
# Perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
@@ -945,12 +831,12 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
return latents
def __decode_latent(self, latents):
images = runEngine(self.engine["vae"], {"latent": latents}, self.stream)["images"]
images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"]
images = (images / 2 + 0.5).clamp(0, 1)
return images.cpu().permute(0, 2, 3, 1).float().numpy()
def __loadResources(self, image_height, image_width, batch_size):
self.stream = cudart.cudaStreamCreate()[1]
self.stream = cuda.Stream()
# Allocate buffers for TensorRT engine bindings
for model_name, obj in self.models.items():
+4 -4
View File
@@ -148,12 +148,12 @@ accelerate launch train_dreambooth_lora_sd3.py \
```
### Text Encoder Training
Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
> [!NOTE]
> SD3 has three text encoders (CLIP L/14, OpenCLIP bigG/14, and T5-v1.1-XXL).
By enabling `--train_text_encoder`, LoRA fine-tuning of both **CLIP encoders** is performed. At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled.
> SD3 has three text encoders (CLIP L/14, OpenCLIP bigG/14, and T5-v1.1-XXL).
By enabling `--train_text_encoder`, LoRA fine-tuning of both **CLIP encoders** is performed. At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled.
To perform DreamBooth LoRA with text-encoder training, run:
```bash
@@ -185,4 +185,4 @@ accelerate launch train_dreambooth_lora_sd3.py \
1. We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities.
2. Thanks to `bghira`, `JinxuXiang`, and `bendanzzc` for helping us discover a bug in how VAE encoding was being done previously. This has been fixed in [#8917](https://github.com/huggingface/diffusers/pull/8917).
3. Additionally, we now have the option to control if we want to apply preconditioning to the model outputs via a `--precondition_outputs` CLI arg. It affects how the model `target` is calculated as well.
3. Additionally, we now have the option to control if we want to apply preconditioning to the model outputs via a `--precondition_outputs` CLI arg. It affects how the model `target` is calculated as well.
@@ -46,4 +46,5 @@ pipe.enable_model_cpu_offload()
# generate image
generator = torch.manual_seed(0)
image = pipe("a tortoise", num_inference_steps=20, generator=generator, image_pair=[image_a,image_b], image=query).images[0]
```
@@ -2051,7 +2051,7 @@ if __name__ == "__main__":
default=512,
type=int,
help=(
"The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
"The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2"
" Base. Use 768 for Stable Diffusion v2."
),
)
@@ -1253,7 +1253,7 @@ class PromptDiffusionPipeline(
)
if guess_mode and self.do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
@@ -11,28 +11,28 @@ huggingface-cli login
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
For setup, inference code, and details on how to run the code, please follow the Colab Notebook provided above.
For setup, inference code, and details on how to run the code, please follow the Colab Notebook provided above.
## How
We make use of several techniques to make this possible:
* Compute the embeddings from the instance prompt and serialize them for later reuse. This is implemented in the [`compute_embeddings.py`](./compute_embeddings.py) script. We use an 8bit (as introduced in [`LLM.int8()`](https://arxiv.org/abs/2208.07339)) T5 to reduce memory requirements to ~10.5GB.
* Compute the embeddings from the instance prompt and serialize them for later reuse. This is implemented in the [`compute_embeddings.py`](./compute_embeddings.py) script. We use an 8bit (as introduced in [`LLM.int8()`](https://arxiv.org/abs/2208.07339)) T5 to reduce memory requirements to ~10.5GB.
* In the `train_dreambooth_sd3_lora_miniature.py` script, we make use of:
* 8bit Adam for optimization through the `bitsandbytes` library.
* Gradient checkpointing and gradient accumulation.
* FP16 precision.
* Flash attention through `F.scaled_dot_product_attention()`.
* Flash attention through `F.scaled_dot_product_attention()`.
Computing the text embeddings is arguably the most memory-intensive part in the pipeline as SD3 employs three text encoders. If we run them in FP32, it will take about 20GB of VRAM. With FP16, we are down to 12GB.
Computing the text embeddings is arguably the most memory-intensive part in the pipeline as SD3 employs three text encoders. If we run them in FP32, it will take about 20GB of VRAM. With FP16, we are down to 12GB.
## Gotchas
This project is educational. It exists to showcase the possibility of fine-tuning a big diffusion system on consumer GPUs. But additional components might have to be added to obtain state-of-the-art performance. Below are some commonly known gotchas that users should be aware of:
* Training of text encoders is purposefully disabled.
* Techniques such as prior-preservation is unsupported.
* Training of text encoders is purposefully disabled.
* Techniques such as prior-preservation is unsupported.
* Custom instance captions for instance images are unsupported, but this should be relatively easy to integrate.
Hopefully, this project gives you a template to extend it further to suit your needs.
-222
View File
@@ -1,222 +0,0 @@
import argparse
from typing import Any, Dict
import torch
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
to_q_key = key.replace("query_key_value", "to_q")
to_k_key = key.replace("query_key_value", "to_k")
to_v_key = key.replace("query_key_value", "to_v")
to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0)
state_dict[to_q_key] = to_q
state_dict[to_k_key] = to_k
state_dict[to_v_key] = to_v
state_dict.pop(key)
def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
layer_id, weight_or_bias = key.split(".")[-2:]
if "query" in key:
new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}"
elif "key" in key:
new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}"
state_dict[new_key] = state_dict.pop(key)
def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
layer_id, _, weight_or_bias = key.split(".")[-3:]
weights_or_biases = state_dict[key].chunk(12, dim=0)
norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])
norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])
norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"
state_dict[norm1_key] = norm1_weights_or_biases
norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"
state_dict[norm2_key] = norm2_weights_or_biases
state_dict.pop(key)
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
state_dict.pop(key)
def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
key_split = key.split(".")
layer_index = int(key_split[2])
replace_layer_index = 4 - 1 - layer_index
key_split[1] = "up_blocks"
key_split[2] = str(replace_layer_index)
new_key = ".".join(key_split)
state_dict[new_key] = state_dict.pop(key)
TRANSFORMER_KEYS_RENAME_DICT = {
"transformer.final_layernorm": "norm_final",
"transformer": "transformer_blocks",
"attention": "attn1",
"mlp": "ff.net",
"dense_h_to_4h": "0.proj",
"dense_4h_to_h": "2",
".layers": "",
"dense": "to_out.0",
"input_layernorm": "norm1.norm",
"post_attn1_layernorm": "norm2.norm",
"time_embed.0": "time_embedding.linear_1",
"time_embed.2": "time_embedding.linear_2",
"mixins.patch_embed": "patch_embed",
"mixins.final_layer.norm_final": "norm_out.norm",
"mixins.final_layer.linear": "proj_out",
"mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"query_key_value": reassign_query_key_value_inplace,
"query_layernorm_list": reassign_query_key_layernorm_inplace,
"key_layernorm_list": reassign_query_key_layernorm_inplace,
"adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
"embed_tokens": remove_keys_inplace,
}
VAE_KEYS_RENAME_DICT = {
"block.": "resnets.",
"down.": "down_blocks.",
"downsample": "downsamplers.0",
"upsample": "upsamplers.0",
"nin_shortcut": "conv_shortcut",
"encoder.mid.block_1": "encoder.mid_block.resnets.0",
"encoder.mid.block_2": "encoder.mid_block.resnets.1",
"decoder.mid.block_1": "decoder.mid_block.resnets.0",
"decoder.mid.block_2": "decoder.mid_block.resnets.1",
}
VAE_SPECIAL_KEYS_REMAP = {
"loss": remove_keys_inplace,
"up.": replace_up_keys_inplace,
}
TOKENIZER_MAX_LENGTH = 226
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
if "model" in saved_dict.keys():
state_dict = state_dict["model"]
if "module" in saved_dict.keys():
state_dict = state_dict["module"]
if "state_dict" in saved_dict.keys():
state_dict = state_dict["state_dict"]
return state_dict
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
state_dict[new_key] = state_dict.pop(old_key)
def convert_transformer(ckpt_path: str):
PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
transformer = CogVideoXTransformer3DModel()
for key in list(original_state_dict.keys()):
new_key = key[len(PREFIX_KEY) :]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
transformer.load_state_dict(original_state_dict, strict=True)
return transformer
def convert_vae(ckpt_path: str):
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
vae = AutoencoderKLCogVideoX()
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True)
return vae
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16")
parser.add_argument(
"--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
)
parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
)
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
transformer = None
vae = None
if args.transformer_ckpt_path is not None:
transformer = convert_transformer(args.transformer_ckpt_path)
if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path)
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
scheduler = CogVideoXDDIMScheduler.from_config(
{
"snr_shift_scale": 3.0,
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": False,
"num_train_timesteps": 1000,
"prediction_type": "v_prediction",
"rescale_betas_zero_snr": True,
"set_alpha_to_one": True,
"timestep_spacing": "linspace",
}
)
pipe = CogVideoXPipeline(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
if args.fp16:
pipe = pipe.to(dtype=torch.float16)
pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
-303
View File
@@ -1,303 +0,0 @@
import argparse
from contextlib import nullcontext
import safetensors.torch
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from diffusers import AutoencoderKL, FluxTransformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.utils.import_utils import is_accelerate_available
"""
# Transformer
python scripts/convert_flux_to_diffusers.py \
--original_state_dict_repo_id "black-forest-labs/FLUX.1-schnell" \
--filename "flux1-schnell.sft"
--output_path "flux-schnell" \
--transformer
"""
"""
# VAE
python scripts/convert_flux_to_diffusers.py \
--original_state_dict_repo_id "black-forest-labs/FLUX.1-schnell" \
--filename "ae.sft"
--output_path "flux-schnell" \
--vae
"""
CTX = init_empty_weights if is_accelerate_available else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--vae", action="store_true")
parser.add_argument("--transformer", action="store_true")
parser.add_argument("--output_path", type=str)
parser.add_argument("--dtype", type=str, default="bf16")
args = parser.parse_args()
dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
def load_original_checkpoint(args):
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
elif args.checkpoint_path is not None:
ckpt_path = args.checkpoint_path
else:
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
original_state_dict = safetensors.torch.load_file(ckpt_path)
return original_state_dict
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
def convert_flux_transformer_checkpoint_to_diffusers(
original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0
):
converted_state_dict = {}
## time_text_embed.timestep_embedder <- time_in
converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
"time_in.in_layer.weight"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
"time_in.in_layer.bias"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
"time_in.out_layer.weight"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
"time_in.out_layer.bias"
)
## time_text_embed.text_embedder <- vector_in
converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
"vector_in.in_layer.weight"
)
converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
"vector_in.in_layer.bias"
)
converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
"vector_in.out_layer.weight"
)
converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
"vector_in.out_layer.bias"
)
# guidance
has_guidance = any("guidance" in k for k in original_state_dict)
if has_guidance:
converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = original_state_dict.pop(
"guidance_in.in_layer.weight"
)
converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = original_state_dict.pop(
"guidance_in.in_layer.bias"
)
converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = original_state_dict.pop(
"guidance_in.out_layer.weight"
)
converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = original_state_dict.pop(
"guidance_in.out_layer.bias"
)
# context_embedder
converted_state_dict["context_embedder.weight"] = original_state_dict.pop("txt_in.weight")
converted_state_dict["context_embedder.bias"] = original_state_dict.pop("txt_in.bias")
# x_embedder
converted_state_dict["x_embedder.weight"] = original_state_dict.pop("img_in.weight")
converted_state_dict["x_embedder.bias"] = original_state_dict.pop("img_in.bias")
# double transformer blocks
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."
# norms.
## norm1
converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_mod.lin.weight"
)
converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
f"double_blocks.{i}.img_mod.lin.bias"
)
## norm1_context
converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mod.lin.weight"
)
converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mod.lin.bias"
)
# Q, K, V
sample_q, sample_k, sample_v = torch.chunk(
original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
)
context_q, context_k, context_v = torch.chunk(
original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
)
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
)
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
# qk_norm
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn.norm.key_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
)
# ff img_mlp
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_mlp.0.weight"
)
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
f"double_blocks.{i}.img_mlp.0.bias"
)
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_mlp.2.weight"
)
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
f"double_blocks.{i}.img_mlp.2.bias"
)
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mlp.0.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mlp.0.bias"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mlp.2.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
f"double_blocks.{i}.txt_mlp.2.bias"
)
# output projections.
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn.proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
f"double_blocks.{i}.img_attn.proj.bias"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn.proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
f"double_blocks.{i}.txt_attn.proj.bias"
)
# single transfomer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."
# norm.linear <- single_blocks.0.modulation.lin
converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop(
f"single_blocks.{i}.modulation.lin.weight"
)
converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop(
f"single_blocks.{i}.modulation.lin.bias"
)
# Q, K, V, mlp
mlp_hidden_dim = int(inner_dim * mlp_ratio)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
q, k, v, mlp = torch.split(original_state_dict.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
q_bias, k_bias, v_bias, mlp_bias = torch.split(
original_state_dict.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
# qk norm
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
f"single_blocks.{i}.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
f"single_blocks.{i}.norm.key_norm.scale"
)
# output projections.
converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop(
f"single_blocks.{i}.linear2.weight"
)
converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop(
f"single_blocks.{i}.linear2.bias"
)
converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
original_state_dict.pop("final_layer.adaLN_modulation.1.weight")
)
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
original_state_dict.pop("final_layer.adaLN_modulation.1.bias")
)
return converted_state_dict
def main(args):
original_ckpt = load_original_checkpoint(args)
has_guidance = any("guidance" in k for k in original_ckpt)
if args.transformer:
num_layers = 19
num_single_layers = 38
inner_dim = 3072
mlp_ratio = 4.0
converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
)
transformer = FluxTransformer2DModel(guidance_embeds=has_guidance)
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
print(
f"Saving Flux Transformer in Diffusers format. Variant: {'guidance-distilled' if has_guidance else 'timestep-distilled'}"
)
transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
if args.vae:
config = AutoencoderKL.load_config("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae")
vae = AutoencoderKL.from_config(config, scaling_factor=0.3611, shift_factor=0.1159).to(torch.bfloat16)
converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config)
vae.load_state_dict(converted_vae_state_dict, strict=True)
vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
if __name__ == "__main__":
main(args)
@@ -42,7 +42,7 @@ if __name__ == "__main__":
default=512,
type=int,
help=(
"The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
"The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2"
" Base. Use 768 for Stable Diffusion v2."
),
)
@@ -67,7 +67,7 @@ if __name__ == "__main__":
default=None,
type=int,
help=(
"The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
"The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2"
" Base. Use 768 for Stable Diffusion v2."
),
)
-279
View File
@@ -1,279 +0,0 @@
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
import argparse
import json
import os
from contextlib import nullcontext
import torch
from safetensors.torch import load_file
from transformers import (
AutoTokenizer,
T5EncoderModel,
)
from diffusers import (
AutoencoderOobleck,
CosineDPMSolverMultistepScheduler,
StableAudioDiTModel,
StableAudioPipeline,
StableAudioProjectionModel,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.utils import is_accelerate_available
if is_accelerate_available():
from accelerate import init_empty_weights
def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_layers=5):
projection_model_state_dict = {
k.replace("conditioner.conditioners.", "").replace("embedder.embedding", "time_positional_embedding"): v
for (k, v) in state_dict.items()
if "conditioner.conditioners" in k
}
# NOTE: we assume here that there's no projection layer from the text encoder to the latent space, script should be adapted a bit if there is.
for key, value in list(projection_model_state_dict.items()):
new_key = key.replace("seconds_start", "start_number_conditioner").replace(
"seconds_total", "end_number_conditioner"
)
projection_model_state_dict[new_key] = projection_model_state_dict.pop(key)
model_state_dict = {k.replace("model.model.", ""): v for (k, v) in state_dict.items() if "model.model." in k}
for key, value in list(model_state_dict.items()):
# attention layers
new_key = (
key.replace("transformer.", "")
.replace("layers", "transformer_blocks")
.replace("self_attn", "attn1")
.replace("cross_attn", "attn2")
.replace("ff.ff", "ff.net")
)
new_key = (
new_key.replace("pre_norm", "norm1")
.replace("cross_attend_norm", "norm2")
.replace("ff_norm", "norm3")
.replace("to_out", "to_out.0")
)
new_key = new_key.replace("gamma", "weight").replace("beta", "bias") # replace layernorm
# other layers
new_key = (
new_key.replace("project", "proj")
.replace("to_timestep_embed", "timestep_proj")
.replace("timestep_features", "time_proj")
.replace("to_global_embed", "global_proj")
.replace("to_cond_embed", "cross_attention_proj")
)
# we're using diffusers implementation of time_proj (GaussianFourierProjection) which creates a 1D tensor
if new_key == "time_proj.weight":
model_state_dict[key] = model_state_dict[key].squeeze(1)
if "to_qkv" in new_key:
q, k, v = torch.chunk(model_state_dict.pop(key), 3, dim=0)
model_state_dict[new_key.replace("qkv", "q")] = q
model_state_dict[new_key.replace("qkv", "k")] = k
model_state_dict[new_key.replace("qkv", "v")] = v
elif "to_kv" in new_key:
k, v = torch.chunk(model_state_dict.pop(key), 2, dim=0)
model_state_dict[new_key.replace("kv", "k")] = k
model_state_dict[new_key.replace("kv", "v")] = v
else:
model_state_dict[new_key] = model_state_dict.pop(key)
autoencoder_state_dict = {
k.replace("pretransform.model.", "").replace("coder.layers.0", "coder.conv1"): v
for (k, v) in state_dict.items()
if "pretransform.model." in k
}
for key, _ in list(autoencoder_state_dict.items()):
new_key = key
if "coder.layers" in new_key:
# get idx of the layer
idx = int(new_key.split("coder.layers.")[1].split(".")[0])
new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}")
if "encoder" in new_key:
for i in range(3):
new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}")
new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1")
new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1")
else:
for i in range(2, 5):
new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}")
new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1")
new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1")
new_key = new_key.replace("layers.0.beta", "snake1.beta")
new_key = new_key.replace("layers.0.alpha", "snake1.alpha")
new_key = new_key.replace("layers.2.beta", "snake2.beta")
new_key = new_key.replace("layers.2.alpha", "snake2.alpha")
new_key = new_key.replace("layers.1.bias", "conv1.bias")
new_key = new_key.replace("layers.1.weight_", "conv1.weight_")
new_key = new_key.replace("layers.3.bias", "conv2.bias")
new_key = new_key.replace("layers.3.weight_", "conv2.weight_")
if idx == num_autoencoder_layers + 1:
new_key = new_key.replace(f"block.{idx-1}", "snake1")
elif idx == num_autoencoder_layers + 2:
new_key = new_key.replace(f"block.{idx-1}", "conv2")
else:
new_key = new_key
value = autoencoder_state_dict.pop(key)
if "snake" in new_key:
value = value.unsqueeze(0).unsqueeze(-1)
if new_key in autoencoder_state_dict:
raise ValueError(f"{new_key} already in state dict.")
autoencoder_state_dict[new_key] = value
return model_state_dict, projection_model_state_dict, autoencoder_state_dict
parser = argparse.ArgumentParser(description="Convert Stable Audio 1.0 model weights to a diffusers pipeline")
parser.add_argument("--model_folder_path", type=str, help="Location of Stable Audio weights and config")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument(
"--save_directory",
type=str,
default="./tmp/stable-audio-1.0",
help="Directory to save a pipeline to. Will be created if it doesn't exist.",
)
parser.add_argument(
"--repo_id",
type=str,
default="stable-audio-1.0",
help="Hub organization to save the pipelines to",
)
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")
args = parser.parse_args()
checkpoint_path = (
os.path.join(args.model_folder_path, "model.safetensors")
if args.use_safetensors
else os.path.join(args.model_folder_path, "model.ckpt")
)
config_path = os.path.join(args.model_folder_path, "model_config.json")
device = "cpu"
if args.variant == "bf16":
dtype = torch.bfloat16
else:
dtype = torch.float32
with open(config_path) as f_in:
config_dict = json.load(f_in)
conditioning_dict = {
conditioning["id"]: conditioning["config"] for conditioning in config_dict["model"]["conditioning"]["configs"]
}
t5_model_config = conditioning_dict["prompt"]
# T5 Text encoder
text_encoder = T5EncoderModel.from_pretrained(t5_model_config["t5_model_name"])
tokenizer = AutoTokenizer.from_pretrained(
t5_model_config["t5_model_name"], truncation=True, model_max_length=t5_model_config["max_length"]
)
# scheduler
scheduler = CosineDPMSolverMultistepScheduler(
sigma_min=0.3,
sigma_max=500,
solver_order=2,
prediction_type="v_prediction",
sigma_data=1.0,
sigma_schedule="exponential",
)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
if args.use_safetensors:
orig_state_dict = load_file(checkpoint_path, device=device)
else:
orig_state_dict = torch.load(checkpoint_path, map_location=device)
model_config = config_dict["model"]["diffusion"]["config"]
model_state_dict, projection_model_state_dict, autoencoder_state_dict = convert_stable_audio_state_dict_to_diffusers(
orig_state_dict
)
with ctx():
projection_model = StableAudioProjectionModel(
text_encoder_dim=text_encoder.config.d_model,
conditioning_dim=config_dict["model"]["conditioning"]["cond_dim"],
min_value=conditioning_dict["seconds_start"][
"min_val"
], # assume `seconds_start` and `seconds_total` have the same min / max values.
max_value=conditioning_dict["seconds_start"][
"max_val"
], # assume `seconds_start` and `seconds_total` have the same min / max values.
)
if is_accelerate_available():
load_model_dict_into_meta(projection_model, projection_model_state_dict)
else:
projection_model.load_state_dict(projection_model_state_dict)
attention_head_dim = model_config["embed_dim"] // model_config["num_heads"]
with ctx():
model = StableAudioDiTModel(
sample_size=int(config_dict["sample_size"])
/ int(config_dict["model"]["pretransform"]["config"]["downsampling_ratio"]),
in_channels=model_config["io_channels"],
num_layers=model_config["depth"],
attention_head_dim=attention_head_dim,
num_key_value_attention_heads=model_config["cond_token_dim"] // attention_head_dim,
num_attention_heads=model_config["num_heads"],
out_channels=model_config["io_channels"],
cross_attention_dim=model_config["cond_token_dim"],
time_proj_dim=256,
global_states_input_dim=model_config["global_cond_dim"],
cross_attention_input_dim=model_config["cond_token_dim"],
)
if is_accelerate_available():
load_model_dict_into_meta(model, model_state_dict)
else:
model.load_state_dict(model_state_dict)
autoencoder_config = config_dict["model"]["pretransform"]["config"]
with ctx():
autoencoder = AutoencoderOobleck(
encoder_hidden_size=autoencoder_config["encoder"]["config"]["channels"],
downsampling_ratios=autoencoder_config["encoder"]["config"]["strides"],
decoder_channels=autoencoder_config["decoder"]["config"]["channels"],
decoder_input_channels=autoencoder_config["decoder"]["config"]["latent_dim"],
audio_channels=autoencoder_config["io_channels"],
channel_multiples=autoencoder_config["encoder"]["config"]["c_mults"],
sampling_rate=config_dict["sample_rate"],
)
if is_accelerate_available():
load_model_dict_into_meta(autoencoder, autoencoder_state_dict)
else:
autoencoder.load_state_dict(autoencoder_state_dict)
# Prior pipeline
pipeline = StableAudioPipeline(
transformer=model,
tokenizer=tokenizer,
text_encoder=text_encoder,
scheduler=scheduler,
vae=autoencoder,
projection_model=projection_model,
)
pipeline.to(dtype).save_pretrained(
args.save_directory, repo_id=args.repo_id, push_to_hub=args.push_to_hub, variant=args.variant
)
+2 -32
View File
@@ -78,16 +78,12 @@ else:
"AsymmetricAutoencoderKL",
"AuraFlowTransformer2DModel",
"AutoencoderKL",
"AutoencoderKLCogVideoX",
"AutoencoderKLTemporalDecoder",
"AutoencoderOobleck",
"AutoencoderTiny",
"CogVideoXTransformer3DModel",
"ConsistencyDecoderVAE",
"ControlNetModel",
"ControlNetXSAdapter",
"DiTTransformer2DModel",
"FluxTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
@@ -104,7 +100,6 @@ else:
"SD3MultiControlNetModel",
"SD3Transformer2DModel",
"SparseControlNetModel",
"StableAudioDiTModel",
"StableCascadeUNet",
"T2IAdapter",
"T5FilmDecoder",
@@ -156,8 +151,6 @@ else:
[
"AmusedScheduler",
"CMStochasticIterativeScheduler",
"CogVideoXDDIMScheduler",
"CogVideoXDPMScheduler",
"DDIMInverseScheduler",
"DDIMParallelScheduler",
"DDIMScheduler",
@@ -217,7 +210,7 @@ except OptionalDependencyNotAvailable:
]
else:
_import_structure["schedulers"].extend(["CosineDPMSolverMultistepScheduler", "DPMSolverSDEScheduler"])
_import_structure["schedulers"].extend(["DPMSolverSDEScheduler"])
try:
if not (is_torch_available() and is_transformers_available()):
@@ -237,8 +230,6 @@ else:
"AmusedImg2ImgPipeline",
"AmusedInpaintPipeline",
"AmusedPipeline",
"AnimateDiffControlNetPipeline",
"AnimateDiffPAGPipeline",
"AnimateDiffPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffSparseControlNetPipeline",
@@ -253,11 +244,8 @@ else:
"ChatGLMModel",
"ChatGLMTokenizer",
"CLIPImageProjection",
"CogVideoXPipeline",
"CycleDiffusionPipeline",
"FluxPipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
"I2VGenXLPipeline",
"IFImg2ImgPipeline",
@@ -301,13 +289,10 @@ else:
"PaintByExamplePipeline",
"PIAPipeline",
"PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
"StableAudioPipeline",
"StableAudioProjectionModel",
"StableCascadeCombinedPipeline",
"StableCascadeDecoderPipeline",
"StableCascadePriorPipeline",
@@ -529,16 +514,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AsymmetricAutoencoderKL,
AuraFlowTransformer2DModel,
AutoencoderKL,
AutoencoderKLCogVideoX,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
CogVideoXTransformer3DModel,
ConsistencyDecoderVAE,
ControlNetModel,
ControlNetXSAdapter,
DiTTransformer2DModel,
FluxTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
@@ -555,7 +536,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SD3MultiControlNetModel,
SD3Transformer2DModel,
SparseControlNetModel,
StableAudioDiTModel,
T2IAdapter,
T5FilmDecoder,
Transformer2DModel,
@@ -604,8 +584,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .schedulers import (
AmusedScheduler,
CMStochasticIterativeScheduler,
CogVideoXDDIMScheduler,
CogVideoXDPMScheduler,
DDIMInverseScheduler,
DDIMParallelScheduler,
DDIMScheduler,
@@ -654,7 +632,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
else:
from .schedulers import CosineDPMSolverMultistepScheduler, DPMSolverSDEScheduler
from .schedulers import DPMSolverSDEScheduler
try:
if not (is_torch_available() and is_transformers_available()):
@@ -668,8 +646,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AmusedImg2ImgPipeline,
AmusedInpaintPipeline,
AmusedPipeline,
AnimateDiffControlNetPipeline,
AnimateDiffPAGPipeline,
AnimateDiffPipeline,
AnimateDiffSDXLPipeline,
AnimateDiffSparseControlNetPipeline,
@@ -682,11 +658,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ChatGLMModel,
ChatGLMTokenizer,
CLIPImageProjection,
CogVideoXPipeline,
CycleDiffusionPipeline,
FluxPipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
I2VGenXLPipeline,
IFImg2ImgPipeline,
@@ -730,13 +703,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PaintByExamplePipeline,
PIAPipeline,
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
StableAudioPipeline,
StableAudioProjectionModel,
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
-2
View File
@@ -66,7 +66,6 @@ if is_torch_available():
"SD3LoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
"LoraLoaderMixin",
"FluxLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
@@ -84,7 +83,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .ip_adapter import IPAdapterMixin
from .lora_pipeline import (
AmusedLoraLoaderMixin,
FluxLoraLoaderMixin,
LoraLoaderMixin,
SD3LoraLoaderMixin,
StableDiffusionLoraLoaderMixin,
-475
View File
@@ -1475,481 +1475,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components)
class FluxLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`FluxTransformer2DModel`],
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
Specific to [`StableDiffusion3Pipeline`].
"""
_lora_loadable_modules = ["transformer", "text_encoder"]
transformer_name = TRANSFORMER_NAME
text_encoder_name = TEXT_ENCODER_NAME
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.
<Tip warning={true}>
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
This function is experimental and might change in the future.
</Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = cls._fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
return state_dict
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
`self.text_encoder`.
All kwargs are forwarded to `self.lora_state_dict`.
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
dict is loaded into `self.transformer`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_state_dict,
network_alphas=None,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
transformer (`SD3Transformer2DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
keys = list(state_dict.keys())
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
state_dict = {
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
}
if len(state_dict.keys()) > 0:
# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key:
state_dict = convert_unet_state_dict_to_peft(state_dict)
if adapter_name in getattr(transformer, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(transformer)
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
def load_lora_into_text_encoder(
cls,
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
adapter_name=None,
_pipeline=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The key should be prefixed with an
additional `text_encoder` to distinguish between unet lora layers.
network_alphas (`Dict[str, float]`):
See `LoRALinearLayer` for more details.
text_encoder (`CLIPTextModel`):
The text encoder model to load the LoRA layers into.
prefix (`str`):
Expected prefix of the `text_encoder` in the `state_dict`.
lora_scale (`float`):
How much to scale the output of the lora linear layer before it is added with the output of the regular
lora layer.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
prefix = cls.text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
]
network_alphas = {
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
state_dict = {}
if not (transformer_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
components: List[str] = ["transformer", "text_encoder"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
"""
super().unfuse_lora(components=components)
# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
-1
View File
@@ -32,7 +32,6 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
"SD3Transformer2DModel": lambda model_cls, weights: weights,
"FluxTransformer2DModel": lambda model_cls, weights: weights,
}
-10
View File
@@ -28,9 +28,7 @@ if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
@@ -42,7 +40,6 @@ if is_torch_available():
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
_import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]
_import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
_import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
@@ -50,10 +47,8 @@ if is_torch_available():
_import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -79,9 +74,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .autoencoders import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLCogVideoX,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
ConsistencyDecoderVAE,
VQModel,
@@ -95,17 +88,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .modeling_utils import ModelMixin
from .transformers import (
AuraFlowTransformer2DModel,
CogVideoXTransformer3DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
FluxTransformer2DModel,
HunyuanDiT2DModel,
LatteTransformer3DModel,
LuminaNextDiT2DModel,
PixArtTransformer2DModel,
PriorTransformer,
SD3Transformer2DModel,
StableAudioDiTModel,
T5FilmDecoder,
Transformer2DModel,
TransformerTemporalModel,
-22
View File
@@ -123,28 +123,6 @@ class GEGLU(nn.Module):
return hidden_states * self.gelu(gate)
class SwiGLU(nn.Module):
r"""
A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU`
but uses SiLU / Swish instead of GeLU.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
self.activation = nn.SiLU()
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states, gate = hidden_states.chunk(2, dim=-1)
return hidden_states * self.activation(gate)
class ApproximateGELU(nn.Module):
r"""
The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
+1 -3
View File
@@ -19,7 +19,7 @@ from torch import nn
from ..utils import deprecate, logging
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU
from .attention_processor import Attention, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
@@ -820,8 +820,6 @@ class FeedForward(nn.Module):
act_fn = GEGLU(dim, inner_dim, bias=bias)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
elif activation_fn == "swiglu":
act_fn = SwiGLU(dim, inner_dim, bias=bias)
self.net = nn.ModuleList([])
# project in
+6 -582
View File
@@ -13,7 +13,7 @@
# limitations under the License.
import inspect
import math
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Union
import torch
import torch.nn.functional as F
@@ -49,10 +49,6 @@ class Attention(nn.Module):
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8):
The number of heads to use for multi-head attention.
kv_heads (`int`, *optional*, defaults to `None`):
The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
`kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
Query Attention (MQA) otherwise GQA is used.
dim_head (`int`, *optional*, defaults to 64):
The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0):
@@ -121,12 +117,11 @@ class Attention(nn.Module):
processor: Optional["AttnProcessor"] = None,
out_dim: int = None,
context_pre_only=None,
pre_only=False,
):
super().__init__()
# To prevent circular import.
from .normalization import FP32LayerNorm, RMSNorm
from .normalization import FP32LayerNorm
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
@@ -142,7 +137,6 @@ class Attention(nn.Module):
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
# we make use of this private variable to know whether this class is loaded
# with an deprecated state dict so that we can convert it on the fly
@@ -188,9 +182,6 @@ class Attention(nn.Module):
# Lumina applys qk norm across all heads
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
elif qk_norm == "rms_norm":
self.norm_q = RMSNorm(dim_head, eps=eps)
self.norm_k = RMSNorm(dim_head, eps=eps)
else:
raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
@@ -233,10 +224,9 @@ class Attention(nn.Module):
if self.context_pre_only is not None:
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
if not self.pre_only:
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
if self.context_pre_only is not None and not self.context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
@@ -245,9 +235,6 @@ class Attention(nn.Module):
if qk_norm == "fp32_layer_norm":
self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
elif qk_norm == "rms_norm":
self.norm_added_q = RMSNorm(dim_head, eps=eps)
self.norm_added_k = RMSNorm(dim_head, eps=eps)
else:
self.norm_added_q = None
self.norm_added_k = None
@@ -539,7 +526,7 @@ class Attention(nn.Module):
return tensor
def get_attention_scores(
self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
) -> torch.Tensor:
r"""
Compute the attention scores.
@@ -1274,179 +1261,6 @@ class AuraFlowAttnProcessor2_0:
return hidden_states
# YiYi to-do: refactor rope related functions/classes
def apply_rope(xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
class FluxSingleAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
# YiYi to-do: update uising apply_rotary_emb
# from ..embeddings import apply_rotary_emb
# query = apply_rotary_emb(query, image_rotary_emb)
# key = apply_rotary_emb(key, image_rotary_emb)
query, key = apply_rope(query, key, image_rotary_emb)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
return hidden_states
class FluxAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size = encoder_hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
# YiYi to-do: update uising apply_rotary_emb
# from ..embeddings import apply_rotary_emb
# query = apply_rotary_emb(query, image_rotary_emb)
# key = apply_rotary_emb(key, image_rotary_emb)
query, key = apply_rope(query, key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
return hidden_states, encoder_hidden_states
class XFormersAttnAddedKVProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
@@ -1785,142 +1599,6 @@ class AttnProcessor2_0:
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class StableAudioAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def apply_partial_rotary_emb(
self,
x: torch.Tensor,
freqs_cis: Tuple[torch.Tensor],
) -> torch.Tensor:
from .embeddings import apply_rotary_emb
rot_dim = freqs_cis[0].shape[-1]
x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:]
x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2)
out = torch.cat((x_rotated, x_unrotated), dim=-1)
return out
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from .embeddings import apply_rotary_emb
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
head_dim = query.shape[-1] // attn.heads
kv_heads = key.shape[-1] // head_dim
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
if kv_heads != attn.heads:
# if GQA or MQA, repeat the key/value heads to reach the number of query heads.
heads_per_kv_head = attn.heads // kv_heads
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if rotary_emb is not None:
query_dtype = query.dtype
key_dtype = key.dtype
query = query.to(torch.float32)
key = key.to(torch.float32)
rot_dim = rotary_emb[0].shape[-1]
query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:]
query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
query = torch.cat((query_rotated, query_unrotated), dim=-1)
if not attn.is_cross_attention:
key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:]
key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
key = torch.cat((key_rotated, key_unrotated), dim=-1)
query = query.to(query_dtype)
key = key.to(key_dtype)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
@@ -2147,253 +1825,6 @@ class FusedHunyuanAttnProcessor2_0:
return hidden_states
class PAGHunyuanAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This
variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377).
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from .embeddings import apply_rotary_emb
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# chunk
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
# 1. Original Path
batch_size, sequence_length, _ = (
hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states_org)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states_org
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
if not attn.is_cross_attention:
key = apply_rotary_emb(key, image_rotary_emb)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query.dtype)
# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)
if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
# 2. Perturbed Path
if attn.group_norm is not None:
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
hidden_states_ptb = attn.to_v(hidden_states_ptb)
hidden_states_ptb = hidden_states_ptb.to(query.dtype)
# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
# cat
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class PAGCFGHunyuanAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This
variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377).
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from .embeddings import apply_rotary_emb
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# chunk
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
# 1. Original Path
batch_size, sequence_length, _ = (
hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states_org)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states_org
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
if not attn.is_cross_attention:
key = apply_rotary_emb(key, image_rotary_emb)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query.dtype)
# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)
if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
# 2. Perturbed Path
if attn.group_norm is not None:
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
hidden_states_ptb = attn.to_v(hidden_states_ptb)
hidden_states_ptb = hidden_states_ptb.to(query.dtype)
# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
# cat
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class LuminaAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
@@ -2566,11 +1997,6 @@ class FusedAttnProcessor2_0:
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
@@ -3715,6 +3141,4 @@ AttentionProcessor = Union[
CustomDiffusionAttnProcessor2_0,
PAGCFGIdentitySelfAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
PAGCFGHunyuanAttnProcessor2_0,
PAGHunyuanAttnProcessor2_0,
]
@@ -1,8 +1,6 @@
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_oobleck import AutoencoderOobleck
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .vq_model import VQModel
File diff suppressed because it is too large Load Diff
@@ -1,464 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ...utils.torch_utils import randn_tensor
from ..modeling_utils import ModelMixin
class Snake1d(nn.Module):
"""
A 1-dimensional Snake activation function module.
"""
def __init__(self, hidden_dim, logscale=True):
super().__init__()
self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1))
self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1))
self.alpha.requires_grad = True
self.beta.requires_grad = True
self.logscale = logscale
def forward(self, hidden_states):
shape = hidden_states.shape
alpha = self.alpha if not self.logscale else torch.exp(self.alpha)
beta = self.beta if not self.logscale else torch.exp(self.beta)
hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2)
hidden_states = hidden_states.reshape(shape)
return hidden_states
class OobleckResidualUnit(nn.Module):
"""
A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
"""
def __init__(self, dimension: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.snake1 = Snake1d(dimension)
self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad))
self.snake2 = Snake1d(dimension)
self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1))
def forward(self, hidden_state):
"""
Forward pass through the residual unit.
Args:
hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
Input tensor .
Returns:
output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`)
Input tensor after passing through the residual unit.
"""
output_tensor = hidden_state
output_tensor = self.conv1(self.snake1(output_tensor))
output_tensor = self.conv2(self.snake2(output_tensor))
padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
if padding > 0:
hidden_state = hidden_state[..., padding:-padding]
output_tensor = hidden_state + output_tensor
return output_tensor
class OobleckEncoderBlock(nn.Module):
"""Encoder block used in Oobleck encoder."""
def __init__(self, input_dim, output_dim, stride: int = 1):
super().__init__()
self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1)
self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3)
self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9)
self.snake1 = Snake1d(input_dim)
self.conv1 = weight_norm(
nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
)
def forward(self, hidden_state):
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
hidden_state = self.snake1(self.res_unit3(hidden_state))
hidden_state = self.conv1(hidden_state)
return hidden_state
class OobleckDecoderBlock(nn.Module):
"""Decoder block used in Oobleck decoder."""
def __init__(self, input_dim, output_dim, stride: int = 1):
super().__init__()
self.snake1 = Snake1d(input_dim)
self.conv_t1 = weight_norm(
nn.ConvTranspose1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
)
)
self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1)
self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3)
self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9)
def forward(self, hidden_state):
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv_t1(hidden_state)
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
hidden_state = self.res_unit3(hidden_state)
return hidden_state
class OobleckDiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.scale = parameters.chunk(2, dim=1)
self.std = nn.functional.softplus(self.scale) + 1e-4
self.var = self.std * self.std
self.logvar = torch.log(self.var)
self.deterministic = deterministic
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = randn_tensor(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()
else:
normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
var_ratio = self.var / other.var
logvar_diff = self.logvar - other.logvar
kl = normalized_diff + var_ratio + logvar_diff - 1
kl = kl.sum(1).mean()
return kl
def mode(self) -> torch.Tensor:
return self.mean
@dataclass
class AutoencoderOobleckOutput(BaseOutput):
"""
Output of AutoencoderOobleck encoding method.
Args:
latent_dist (`OobleckDiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and standard deviation of
`OobleckDiagonalGaussianDistribution`. `OobleckDiagonalGaussianDistribution` allows for sampling latents
from the distribution.
"""
latent_dist: "OobleckDiagonalGaussianDistribution" # noqa: F821
@dataclass
class OobleckDecoderOutput(BaseOutput):
r"""
Output of decoding method.
Args:
sample (`torch.Tensor` of shape `(batch_size, audio_channels, sequence_length)`):
The decoded output sample from the last layer of the model.
"""
sample: torch.Tensor
class OobleckEncoder(nn.Module):
"""Oobleck Encoder"""
def __init__(self, encoder_hidden_size, audio_channels, downsampling_ratios, channel_multiples):
super().__init__()
strides = downsampling_ratios
channel_multiples = [1] + channel_multiples
# Create first convolution
self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3))
self.block = []
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride_index, stride in enumerate(strides):
self.block += [
OobleckEncoderBlock(
input_dim=encoder_hidden_size * channel_multiples[stride_index],
output_dim=encoder_hidden_size * channel_multiples[stride_index + 1],
stride=stride,
)
]
self.block = nn.ModuleList(self.block)
d_model = encoder_hidden_size * channel_multiples[-1]
self.snake1 = Snake1d(d_model)
self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1))
def forward(self, hidden_state):
hidden_state = self.conv1(hidden_state)
for module in self.block:
hidden_state = module(hidden_state)
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv2(hidden_state)
return hidden_state
class OobleckDecoder(nn.Module):
"""Oobleck Decoder"""
def __init__(self, channels, input_channels, audio_channels, upsampling_ratios, channel_multiples):
super().__init__()
strides = upsampling_ratios
channel_multiples = [1] + channel_multiples
# Add first conv layer
self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3))
# Add upsampling + MRF blocks
block = []
for stride_index, stride in enumerate(strides):
block += [
OobleckDecoderBlock(
input_dim=channels * channel_multiples[len(strides) - stride_index],
output_dim=channels * channel_multiples[len(strides) - stride_index - 1],
stride=stride,
)
]
self.block = nn.ModuleList(block)
output_dim = channels
self.snake1 = Snake1d(output_dim)
self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False))
def forward(self, hidden_state):
hidden_state = self.conv1(hidden_state)
for layer in self.block:
hidden_state = layer(hidden_state)
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv2(hidden_state)
return hidden_state
class AutoencoderOobleck(ModelMixin, ConfigMixin):
r"""
An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First
introduced in Stable Audio.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
encoder_hidden_size (`int`, *optional*, defaults to 128):
Intermediate representation dimension for the encoder.
downsampling_ratios (`List[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`):
Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder.
channel_multiples (`List[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`):
Multiples used to determine the hidden sizes of the hidden layers.
decoder_channels (`int`, *optional*, defaults to 128):
Intermediate representation dimension for the decoder.
decoder_input_channels (`int`, *optional*, defaults to 64):
Input dimension for the decoder. Corresponds to the latent dimension.
audio_channels (`int`, *optional*, defaults to 2):
Number of channels in the audio data. Either 1 for mono or 2 for stereo.
sampling_rate (`int`, *optional*, defaults to 44100):
The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
"""
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
encoder_hidden_size=128,
downsampling_ratios=[2, 4, 4, 8, 8],
channel_multiples=[1, 2, 4, 8, 16],
decoder_channels=128,
decoder_input_channels=64,
audio_channels=2,
sampling_rate=44100,
):
super().__init__()
self.encoder_hidden_size = encoder_hidden_size
self.downsampling_ratios = downsampling_ratios
self.decoder_channels = decoder_channels
self.upsampling_ratios = downsampling_ratios[::-1]
self.hop_length = int(np.prod(downsampling_ratios))
self.sampling_rate = sampling_rate
self.encoder = OobleckEncoder(
encoder_hidden_size=encoder_hidden_size,
audio_channels=audio_channels,
downsampling_ratios=downsampling_ratios,
channel_multiples=channel_multiples,
)
self.decoder = OobleckDecoder(
channels=decoder_channels,
input_channels=decoder_input_channels,
audio_channels=audio_channels,
upsampling_ratios=self.upsampling_ratios,
channel_multiples=channel_multiples,
)
self.use_slicing = False
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderOobleckOutput, Tuple[OobleckDiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self.encoder(x)
posterior = OobleckDiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderOobleckOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[OobleckDecoderOutput, torch.Tensor]:
dec = self.decoder(z)
if not return_dict:
return (dec,)
return OobleckDecoderOutput(sample=dec)
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> Union[OobleckDecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.OobleckDecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.OobleckDecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.OobleckDecoderOutput`] is returned, otherwise a plain `tuple`
is returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return OobleckDecoderOutput(sample=decoded)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[OobleckDecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`OobleckDecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return OobleckDecoderOutput(sample=dec)
@@ -32,7 +32,10 @@ from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unets.unet_2d_blocks import UNetMidBlock2DCrossAttn
from .unets.unet_2d_condition import UNet2DConditionModel
from .unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion
from .unets.unet_3d_blocks import (
CrossAttnDownBlockMotion,
DownBlockMotion,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -314,6 +317,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin):
temporal_num_attention_heads=motion_num_attention_heads[i],
temporal_max_seq_length=motion_max_seq_length,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
temporal_double_self_attention=False,
)
elif down_block_type == "DownBlockMotion":
down_block = DownBlockMotion(
@@ -330,6 +334,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin):
add_downsample=not is_final_block,
temporal_num_attention_heads=motion_num_attention_heads[i],
temporal_max_seq_length=motion_max_seq_length,
temporal_double_self_attention=False,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i],
)
else:
-68
View File
@@ -285,74 +285,6 @@ class KDownsample2D(nn.Module):
return F.conv2d(inputs, weight, stride=2)
class CogVideoXDownsample3D(nn.Module):
# Todo: Wait for paper relase.
r"""
A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
Args:
in_channels (`int`):
Number of channels in the input image.
out_channels (`int`):
Number of channels produced by the convolution.
kernel_size (`int`, defaults to `3`):
Size of the convolving kernel.
stride (`int`, defaults to `2`):
Stride of the convolution.
padding (`int`, defaults to `0`):
Padding added to all four sides of the input.
compress_time (`bool`, defaults to `False`):
Whether or not to compress the time dimension.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 2,
padding: int = 0,
compress_time: bool = False,
):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.compress_time:
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
if x.shape[-1] % 2 == 1:
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
else:
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
x = F.avg_pool1d(x, kernel_size=2, stride=2)
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
# Pad the tensor
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
x = self.conv(x)
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
return x
def downsample_2d(
hidden_states: torch.Tensor,
kernel: Optional[torch.Tensor] = None,
+6 -140
View File
@@ -78,53 +78,6 @@ def get_timestep_embedding(
return emb
def get_3d_sincos_pos_embed(
embed_dim: int,
spatial_size: Union[int, Tuple[int, int]],
temporal_size: int,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
) -> np.ndarray:
r"""
Args:
embed_dim (`int`):
spatial_size (`int` or `Tuple[int, int]`):
temporal_size (`int`):
spatial_interpolation_scale (`float`, defaults to 1.0):
temporal_interpolation_scale (`float`, defaults to 1.0):
"""
if embed_dim % 4 != 0:
raise ValueError("`embed_dim` must be divisible by 4")
if isinstance(spatial_size, int):
spatial_size = (spatial_size, spatial_size)
embed_dim_spatial = 3 * embed_dim // 4
embed_dim_temporal = embed_dim // 4
# 1. Spatial
grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
# 2. Temporal
grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
# 3. Concat
pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) # [T, H*W, D // 4 * 3]
pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1) # [T, H*W, D // 4]
pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) # [T, H*W, D]
return pos_embed
def get_2d_sincos_pos_embed(
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
@@ -334,46 +287,6 @@ class LuminaPatchEmbed(nn.Module):
)
class CogVideoXPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
embed_dim: int = 1920,
text_embed_dim: int = 4096,
bias: bool = True,
) -> None:
super().__init__()
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
r"""
Args:
text_embeds (`torch.Tensor`):
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
image_embeds (`torch.Tensor`):
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
"""
text_embeds = self.text_proj(text_embeds)
batch, num_frames, channels, height, width = image_embeds.shape
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
embeds = torch.cat(
[text_embeds, image_embeds], dim=1
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
return embeds
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
"""
RoPE for image tokens with 2d structure.
@@ -389,7 +302,7 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns:
`torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
`torch.Tensor`: positional embdding with shape `( grid_size * grid_size, embed_dim/2)`.
"""
start, stop = crops_coords
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
@@ -439,13 +352,7 @@ def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, n
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[np.ndarray, int],
theta: float = 10000.0,
use_real=False,
linear_factor=1.0,
ntk_factor=1.0,
repeat_interleave_real=True,
dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False, linear_factor=1.0, ntk_factor=1.0
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -465,9 +372,6 @@ def get_1d_rotary_pos_embed(
Scaling factor for the context extrapolation. Defaults to 1.0.
ntk_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
Otherwise, they are concateanted with themselves.
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
@@ -479,14 +383,10 @@ def get_1d_rotary_pos_embed(
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2]
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
if use_real and repeat_interleave_real:
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
elif use_real:
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
@@ -496,7 +396,6 @@ def apply_rotary_emb(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
use_real_unbind_dim: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -518,17 +417,8 @@ def apply_rotary_emb(
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
# Use for example in Lumina
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Use for example in Stable Audio
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
@@ -882,30 +772,6 @@ class CombinedTimestepTextProjEmbeddings(nn.Module):
return conditioning
class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
def forward(self, timestep, guidance, pooled_projection):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
time_guidance_emb = timesteps_emb + guidance_emb
pooled_projections = self.text_embedder(pooled_projection)
conditioning = time_guidance_emb + pooled_projections
return conditioning
class HunyuanDiTAttentionPool(nn.Module):
# Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
@@ -989,7 +855,7 @@ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024)
if self.use_style_cond_and_image_meta_size:
# extra condition2: image meta size embedding
# extra condition2: image meta size embdding
image_meta_size = self.size_proj(image_meta_size.view(-1))
image_meta_size = image_meta_size.to(dtype=hidden_dtype)
image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
+7 -91
View File
@@ -37,44 +37,16 @@ class AdaLayerNorm(nn.Module):
num_embeddings (`int`): The size of the embeddings dictionary.
"""
def __init__(
self,
embedding_dim: int,
num_embeddings: Optional[int] = None,
output_dim: Optional[int] = None,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
chunk_dim: int = 0,
):
def __init__(self, embedding_dim: int, num_embeddings: int):
super().__init__()
self.chunk_dim = chunk_dim
if num_embeddings is not None:
self.emb = nn.Embedding(num_embeddings, embedding_dim)
else:
self.emb = None
output_dim = output_dim or embedding_dim * 2
self.emb = nn.Embedding(num_embeddings, embedding_dim)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
def forward(
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
) -> torch.Tensor:
if self.emb is not None:
temb = self.emb(timestep)
temb = self.linear(self.silu(temb))
if self.chunk_dim == 1:
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]
else:
scale, shift = temb.chunk(2, dim=0)
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(self.emb(timestep)))
scale, shift = torch.chunk(emb, 2)
x = self.norm(x) * (1 + scale) + shift
return x
@@ -134,38 +106,6 @@ class AdaLayerNormZero(nn.Module):
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaLayerNormZeroSingle(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""
def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa
class LuminaRMSNormZero(nn.Module):
"""
Norm layer adaptive RMS normalization zero.
@@ -349,30 +289,6 @@ class LuminaLayerNormContinuous(nn.Module):
return x
class CogVideoXLayerNormZero(nn.Module):
def __init__(
self,
conditioning_dim: int,
embedding_dim: int,
elementwise_affine: bool = True,
eps: float = 1e-5,
bias: bool = True,
) -> None:
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
if is_torch_version(">=", "2.1.0"):
LayerNorm = nn.LayerNorm
else:
@@ -3,7 +3,6 @@ from ...utils import is_torch_available
if is_torch_available():
from .auraflow_transformer_2d import AuraFlowTransformer2DModel
from .cogvideox_transformer_3d import CogVideoXTransformer3DModel
from .dit_transformer_2d import DiTTransformer2DModel
from .dual_transformer_2d import DualTransformer2DModel
from .hunyuan_transformer_2d import HunyuanDiT2DModel
@@ -11,9 +10,7 @@ if is_torch_available():
from .lumina_nextdit2d import LuminaNextDiT2DModel
from .pixart_transformer_2d import PixArtTransformer2DModel
from .prior_transformer import PriorTransformer
from .stable_audio_transformer import StableAudioDiTModel
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
@@ -1,352 +0,0 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Union
import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
r"""
Transformer block used in CogVideoX model. TODO: add link to CogVideoX upon release
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
time_embed_dim: int,
dropout: float = 0.0,
activation_fn: str = "gelu-approximate",
attention_bias: bool = False,
qk_norm: bool = True,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
final_dropout: bool = True,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.attn1 = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=attention_bias,
out_bias=attention_out_bias,
)
# 2. Feed Forward
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
) -> torch.Tensor:
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
# attention
text_length = norm_encoder_hidden_states.size(1)
# CogVideoX uses concatenated text + video embeddings with self-attention instead of using
# them in cross-attention individually
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
attn_output = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
)
hidden_states = hidden_states + gate_msa * attn_output[:, text_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length]
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states, encoder_hidden_states, temb
)
# feed-forward
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_ff * ff_output[:, text_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length]
return hidden_states, encoder_hidden_states
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data in CogVideoX. TODO: add link to CogVideoX upon release
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input.
out_channels (`int`, *optional*):
The number of channels in the output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
patch_size (`int`, *optional*):
The size of the patches to use in the patch embedding layer.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
num_embeds_ada_norm ( `int`, *optional*):
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
added to the hidden states. During inference, you can denoise for up to but not more steps than
`num_embeds_ada_norm`.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether or not to use elementwise affine in normalization layers.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
caption_channels (`int`, *optional*):
The number of channels in the caption embeddings.
video_length (`int`, *optional*):
The number of frames in the video-like data.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 64,
in_channels: Optional[int] = 16,
out_channels: Optional[int] = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
text_embed_dim: int = 4096,
num_layers: int = 30,
dropout: float = 0.0,
attention_bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 49,
patch_size: int = 2,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
post_patch_height = sample_height // patch_size
post_patch_width = sample_width // patch_size
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
self.embedding_dropout = nn.Dropout(dropout)
# 2. 3D positional embeddings
spatial_pos_embedding = get_3d_sincos_pos_embed(
inner_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
spatial_interpolation_scale,
temporal_interpolation_scale,
)
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 3. Time embeddings
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
# 4. Define spatio-temporal transformers blocks
self.transformer_blocks = nn.ModuleList(
[
CogVideoXBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
# 5. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * inner_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
batch_size, num_frames, channels, height, width = hidden_states.shape
# 1. Time embedding
timesteps = timestep
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
# 2. Patch embedding
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
# 3. Position embedding
seq_length = height * width * num_frames // (self.config.patch_size**2)
pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length]
hidden_states = hidden_states + pos_embeds
hidden_states = self.embedding_dropout(hidden_states)
encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
hidden_states = hidden_states[:, self.config.max_text_seq_length :]
# 5. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
)
hidden_states = self.norm_final(hidden_states)
# 6. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 7. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional
import torch
from torch import nn
@@ -19,7 +19,6 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ..attention import BasicTransformerBlock
from ..attention_processor import AttentionProcessor
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -187,66 +186,6 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def forward(
self,
hidden_states: torch.Tensor,
@@ -1,458 +0,0 @@
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import FeedForward
from ...models.attention_processor import (
Attention,
AttentionProcessor,
StableAudioAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_2d import Transformer2DModelOutput
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableAudioGaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""
# Copied from diffusers.models.embeddings.GaussianFourierProjection.__init__
def __init__(
self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
):
super().__init__()
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
self.log = log
self.flip_sin_to_cos = flip_sin_to_cos
if set_W_to_weight:
# to delete later
del self.weight
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
self.weight = self.W
del self.W
def forward(self, x):
if self.log:
x = torch.log(x)
x_proj = 2 * np.pi * x[:, None] @ self.weight[None, :]
if self.flip_sin_to_cos:
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
else:
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
return out
@maybe_allow_in_graph
class StableAudioDiTBlock(nn.Module):
r"""
Transformer block used in Stable Audio model (https://github.com/Stability-AI/stable-audio-tools). Allow skip
connection and QKNorm
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for the query states.
num_key_value_attention_heads (`int`): The number of heads to use for the key and value states.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
num_key_value_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
upcast_attention: bool = False,
norm_eps: float = 1e-5,
ff_inner_dim: Optional[int] = None,
):
super().__init__()
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
out_bias=False,
processor=StableAudioAttnProcessor2_0(),
)
# 2. Cross-Attn
self.norm2 = nn.LayerNorm(dim, norm_eps, True)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
kv_heads=num_key_value_attention_heads,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
out_bias=False,
processor=StableAudioAttnProcessor2_0(),
) # is self-attn if encoder_hidden_states is none
# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim, norm_eps, True)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn="swiglu",
final_dropout=False,
inner_dim=ff_inner_dim,
bias=True,
)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
rotary_embedding: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(
norm_hidden_states,
attention_mask=attention_mask,
rotary_emb=rotary_embedding,
)
hidden_states = attn_output + hidden_states
# 2. Cross-Attention
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
)
hidden_states = attn_output + hidden_states
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
return hidden_states
class StableAudioDiTModel(ModelMixin, ConfigMixin):
"""
The Diffusion Transformer model introduced in Stable Audio.
Reference: https://github.com/Stability-AI/stable-audio-tools
Parameters:
sample_size ( `int`, *optional*, defaults to 1024): The size of the input sample.
in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for the query states.
num_key_value_attention_heads (`int`, *optional*, defaults to 12):
The number of heads to use for the key and value states.
out_channels (`int`, defaults to 64): Number of output channels.
cross_attention_dim ( `int`, *optional*, defaults to 768): Dimension of the cross-attention projection.
time_proj_dim ( `int`, *optional*, defaults to 256): Dimension of the timestep inner projection.
global_states_input_dim ( `int`, *optional*, defaults to 1536):
Input dimension of the global hidden states projection.
cross_attention_input_dim ( `int`, *optional*, defaults to 768):
Input dimension of the cross-attention projection
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: int = 1024,
in_channels: int = 64,
num_layers: int = 24,
attention_head_dim: int = 64,
num_attention_heads: int = 24,
num_key_value_attention_heads: int = 12,
out_channels: int = 64,
cross_attention_dim: int = 768,
time_proj_dim: int = 256,
global_states_input_dim: int = 1536,
cross_attention_input_dim: int = 768,
):
super().__init__()
self.sample_size = sample_size
self.out_channels = out_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.time_proj = StableAudioGaussianFourierProjection(
embedding_size=time_proj_dim // 2,
flip_sin_to_cos=True,
log=False,
set_W_to_weight=False,
)
self.timestep_proj = nn.Sequential(
nn.Linear(time_proj_dim, self.inner_dim, bias=True),
nn.SiLU(),
nn.Linear(self.inner_dim, self.inner_dim, bias=True),
)
self.global_proj = nn.Sequential(
nn.Linear(global_states_input_dim, self.inner_dim, bias=False),
nn.SiLU(),
nn.Linear(self.inner_dim, self.inner_dim, bias=False),
)
self.cross_attention_proj = nn.Sequential(
nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False),
nn.SiLU(),
nn.Linear(cross_attention_dim, cross_attention_dim, bias=False),
)
self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False)
self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False)
self.transformer_blocks = nn.ModuleList(
[
StableAudioDiTBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
num_key_value_attention_heads=num_key_value_attention_heads,
attention_head_dim=attention_head_dim,
cross_attention_dim=cross_attention_dim,
)
for i in range(num_layers)
]
)
self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=False)
self.postprocess_conv = nn.Conv1d(self.out_channels, self.out_channels, 1, bias=False)
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.transformers.hunyuan_transformer_2d.HunyuanDiT2DModel.set_default_attn_processor with Hunyuan->StableAudio
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.set_attn_processor(StableAudioAttnProcessor2_0())
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.FloatTensor,
timestep: torch.LongTensor = None,
encoder_hidden_states: torch.FloatTensor = None,
global_hidden_states: torch.FloatTensor = None,
rotary_embedding: torch.FloatTensor = None,
return_dict: bool = True,
attention_mask: Optional[torch.LongTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`StableAudioDiTModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, in_channels, sequence_len)`):
Input `hidden_states`.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_sequence_len, cross_attention_input_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
global_hidden_states (`torch.FloatTensor` of shape `(batch size, global_sequence_len, global_states_input_dim)`):
Global embeddings that will be prepended to the hidden states.
rotary_embedding (`torch.Tensor`):
The rotary embeddings to apply on query and key tensors during attention calculation.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*):
Mask to avoid performing attention on padding token indices, formed by concatenating the attention
masks
for the two text encoders together. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*):
Mask to avoid performing attention on padding token cross-attention indices, formed by concatenating
the attention masks
for the two text encoders together. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states)
global_hidden_states = self.global_proj(global_hidden_states)
time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.dtype)))
global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1)
hidden_states = self.preprocess_conv(hidden_states) + hidden_states
# (batch_size, dim, sequence_length) -> (batch_size, sequence_length, dim)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.proj_in(hidden_states)
# prepend global states to hidden states
hidden_states = torch.cat([global_hidden_states, hidden_states], dim=-2)
if attention_mask is not None:
prepend_mask = torch.ones((hidden_states.shape[0], 1), device=hidden_states.device, dtype=torch.bool)
attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)
for block in self.transformer_blocks:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
attention_mask,
cross_attention_hidden_states,
encoder_attention_mask,
rotary_embedding,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=cross_attention_hidden_states,
encoder_attention_mask=encoder_attention_mask,
rotary_embedding=rotary_embedding,
)
hidden_states = self.proj_out(hidden_states)
# (batch_size, sequence_length, dim) -> (batch_size, dim, sequence_length)
# remove prepend length that has been added by global hidden states
hidden_states = hidden_states.transpose(1, 2)[:, :, 1:]
hidden_states = self.postprocess_conv(hidden_states) + hidden_states
if not return_dict:
return (hidden_states,)
return Transformer2DModelOutput(sample=hidden_states)
@@ -1,451 +0,0 @@
# Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...models.attention import FeedForward
from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# YiYi to-do: refactor rope related functions/classes
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape
out = torch.einsum("...n,d->...nd", pos, omega)
cos_out = torch.cos(out)
sin_out = torch.sin(out)
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.float()
# YiYi to-do: refactor rope related functions/classes
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
r"""
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
Reference: https://arxiv.org/abs/2403.03206
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
processing of `context` conditions.
"""
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
processor = FluxSingleAttnProcessor2_0()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
def forward(
self,
hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
image_rotary_emb=None,
):
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
)
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
gate = gate.unsqueeze(1)
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
@maybe_allow_in_graph
class FluxTransformerBlock(nn.Module):
r"""
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
Reference: https://arxiv.org/abs/2403.03206
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
processing of `context` conditions.
"""
def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
if hasattr(F, "scaled_dot_product_attention"):
processor = FluxAttnProcessor2_0()
else:
raise ValueError(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
)
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=processor,
qk_norm=qk_norm,
eps=eps,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
image_rotary_emb=None,
):
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
return encoder_hidden_states, hidden_states
class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
The Transformer model introduced in Flux.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
Parameters:
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
self.time_text_embed = text_time_guidance_cls(
embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
)
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pos_embed(ids)
for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
-64
View File
@@ -348,70 +348,6 @@ class KUpsample2D(nn.Module):
return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
class CogVideoXUpsample3D(nn.Module):
r"""
A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
Args:
in_channels (`int`):
Number of channels in the input image.
out_channels (`int`):
Number of channels produced by the convolution.
kernel_size (`int`, defaults to `3`):
Size of the convolving kernel.
stride (`int`, defaults to `1`):
Stride of the convolution.
padding (`int`, defaults to `1`):
Padding added to all four sides of the input.
compress_time (`bool`, defaults to `False`):
Whether or not to compress the time dimension.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: int = 1,
compress_time: bool = False,
) -> None:
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
if self.compress_time:
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
# split first frame
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
x_first = F.interpolate(x_first, scale_factor=2.0)
x_rest = F.interpolate(x_rest, scale_factor=2.0)
x_first = x_first[:, :, None, :, :]
inputs = torch.cat([x_first, x_rest], dim=2)
elif inputs.shape[2] > 1:
inputs = F.interpolate(inputs, scale_factor=2.0)
else:
inputs = inputs.squeeze(2)
inputs = F.interpolate(inputs, scale_factor=2.0)
inputs = inputs[:, :, None, :, :]
else:
# only interpolate 2D
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = F.interpolate(inputs, scale_factor=2.0)
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = self.conv(inputs)
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
return inputs
def upfirdn2d_native(
tensor: torch.Tensor,
kernel: torch.Tensor,
+1 -1
View File
@@ -87,7 +87,7 @@ def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_
The optimizer for which to schedule the learning rate.
step_rules (`string`):
The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate
if multiple 1 for the first 10 steps, multiple 0.1 for the next 20 steps, multiple 0.01 for the next 30
if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30
steps and multiple 0.005 for the other steps.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
-17
View File
@@ -118,12 +118,10 @@ else:
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
_import_structure["animatediff"] = [
"AnimateDiffPipeline",
"AnimateDiffControlNetPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffSparseControlNetPipeline",
"AnimateDiffVideoToVideoPipeline",
]
_import_structure["flux"] = ["FluxPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
@@ -131,7 +129,6 @@ else:
"AudioLDM2UNet2DConditionModel",
]
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
_import_structure["cogvideo"] = ["CogVideoXPipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
@@ -145,15 +142,12 @@ else:
)
_import_structure["pag"].extend(
[
"AnimateDiffPAGPipeline",
"HunyuanDiTPAGPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionControlNetPAGPipeline",
"StableDiffusionXLPAGPipeline",
"StableDiffusionXLPAGInpaintPipeline",
"StableDiffusionXLControlNetPAGPipeline",
"StableDiffusionXLPAGImg2ImgPipeline",
"PixArtSigmaPAGPipeline",
]
)
_import_structure["controlnet_xs"].extend(
@@ -237,10 +231,6 @@ else:
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
_import_structure["stable_audio"] = [
"StableAudioProjectionModel",
"StableAudioPipeline",
]
_import_structure["stable_cascade"] = [
"StableCascadeCombinedPipeline",
"StableCascadeDecoderPipeline",
@@ -425,7 +415,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
from .animatediff import (
AnimateDiffControlNetPipeline,
AnimateDiffPipeline,
AnimateDiffSDXLPipeline,
AnimateDiffSparseControlNetPipeline,
@@ -439,7 +428,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .cogvideo import CogVideoXPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,
@@ -481,7 +469,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VersatileDiffusionTextToImagePipeline,
VQDiffusionPipeline,
)
from .flux import FluxPipeline
from .hunyuandit import HunyuanDiTPipeline
from .i2vgen_xl import I2VGenXLPipeline
from .kandinsky import (
@@ -534,9 +521,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .musicldm import MusicLDMPipeline
from .pag import (
AnimateDiffPAGPipeline,
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
@@ -549,7 +533,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
from .stable_cascade import (
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
@@ -22,7 +22,6 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"]
_import_structure["pipeline_animatediff_controlnet"] = ["AnimateDiffControlNetPipeline"]
_import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"]
_import_structure["pipeline_animatediff_sparsectrl"] = ["AnimateDiffSparseControlNetPipeline"]
_import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]
@@ -36,7 +35,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipeline_animatediff import AnimateDiffPipeline
from .pipeline_animatediff_controlnet import AnimateDiffControlNetPipeline
from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
from .pipeline_animatediff_sparsectrl import AnimateDiffSparseControlNetPipeline
from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
File diff suppressed because it is too large Load Diff
-6
View File
@@ -28,7 +28,6 @@ from .controlnet import (
StableDiffusionXLControlNetPipeline,
)
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .flux import FluxPipeline
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
KandinskyCombinedPipeline,
@@ -50,8 +49,6 @@ from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
from .kolors import KolorsImg2ImgPipeline, KolorsPipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .pag import (
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
@@ -86,7 +83,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-3", StableDiffusion3Pipeline),
("if", IFPipeline),
("hunyuan", HunyuanDiTPipeline),
("hunyuan-pag", HunyuanDiTPAGPipeline),
("kandinsky", KandinskyCombinedPipeline),
("kandinsky22", KandinskyV22CombinedPipeline),
("kandinsky3", Kandinsky3Pipeline),
@@ -101,10 +97,8 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGPipeline),
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGPipeline),
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
("auraflow", AuraFlowPipeline),
("kolors", KolorsPipeline),
("flux", FluxPipeline),
]
)
@@ -1,48 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_cogvideox import CogVideoXPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -1,686 +0,0 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import T5EncoderModel, T5Tokenizer
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import BaseOutput, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> from diffusers import CogVideoXPipeline
>>> from diffusers.utils import export_to_video
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda")
>>> prompt = (
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
... "atmosphere of this unique musical performance."
... )
>>> video = pipe(
... "a polar bear dancing, high quality, realistic", guidance_scale=6, num_inference_steps=20
... ).frames[0]
>>> export_to_video(video, "output.mp4", fps=8)
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
@dataclass
class CogVideoXPipelineOutput(BaseOutput):
r"""
Output class for CogVideo pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor
class CogVideoXPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using CogVideoX.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. CogVideoX uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
tokenizer (`T5Tokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`CogVideoXTransformer3DModel`]):
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
"""
_optional_components = ["tokenizer", "text_encoder"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
]
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModel,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 226
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, negative_prompt_embeds
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
shape = (
batch_size,
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
num_channels_latents,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def decode_latents(self, latents: torch.Tensor, num_seconds: int):
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae.config.scaling_factor * latents
frames = []
for i in range(num_seconds):
# Whether or not to clear fake context parallel cache
fake_cp = i + 1 < num_seconds
start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
current_frames = self.vae.decode(latents[:, :, start_frame:end_frame], fake_cp=fake_cp).sample
frames.append(current_frames)
frames = torch.cat(frames, dim=2)
return frames
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
@property
def guidance_scale(self):
return self._guidance_scale
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 480,
width: int = 720,
num_frames: int = 48,
fps: int = 8,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
num_videos_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: str = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
use_dynamic_cfg: bool = False,
) -> Union[CogVideoXPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_frames (`int`, defaults to `48`):
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
needs to be satisfied is that of divisibility mentioned above.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
assert (
num_frames <= 48 and num_frames % fps == 0 and fps == 8
), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX."
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._interrupt = False
# 2. Default call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self._num_timesteps = len(timesteps)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
num_frames += 1
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
latent_channels,
num_frames,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
old_pred_original_sample = None
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance
if use_dynamic_cfg:
self._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if not output_type == "latents":
video = self.decode_latents(latents, num_frames // fps)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return CogVideoXPipelineOutput(frames=video)
@@ -1272,7 +1272,7 @@ class StableDiffusionControlNetPipeline(
)
if guess_mode and self.do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
@@ -1244,7 +1244,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
)
if guess_mode and self.do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
@@ -1408,7 +1408,7 @@ class StableDiffusionControlNetInpaintPipeline(
)
if guess_mode and self.do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
@@ -1739,7 +1739,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
)
if guess_mode and self.do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
@@ -1487,7 +1487,7 @@ class StableDiffusionXLControlNetPipeline(
)
if guess_mode and self.do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
@@ -1551,7 +1551,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
)
if guess_mode and self.do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
-47
View File
@@ -1,47 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["FluxPipelineOutput"]}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_flux"] = ["FluxPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_flux import FluxPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)
@@ -1,749 +0,0 @@
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import FluxPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import FluxPipeline
>>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> # Depending on the variant being used, the pipeline call will slightly vary.
>>> # Refer to the pipeline documentation for more details.
>>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
>>> image.save("flux.png")
```
"""
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.16,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
r"""
The Flux pipeline for text-to-image generation.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
Args:
transformer ([`FluxTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
text_encoder_2 ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer_2 (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
self.default_sample_size = 64
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
def _get_clip_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
):
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer_max_length,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in all text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
# We only use the pooled prompt output from the CLIPTextModel
pooled_prompt_embeds = self._get_clip_prompt_embeds(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
)
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt_2,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
return prompt_embeds, pooled_prompt_embeds, text_ids
def check_inputs(
self,
prompt,
prompt_2,
height,
width,
prompt_embeds=None,
pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt_2 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
return latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = 2 * (int(height) // self.vae_scale_factor)
width = 2 * (int(width) // self.vae_scale_factor)
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents, latent_image_ids
@property
def guidance_scale(self):
return self._guidance_scale
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
will be used instead
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.tensor([guidance_scale], device=device)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
noise_pred = self.transformer(
hidden_states=latents,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
@@ -1,21 +0,0 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class FluxPipelineOutput(BaseOutput):
"""
Output class for Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
-6
View File
@@ -24,10 +24,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
_import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
_import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"]
@@ -42,10 +39,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
from .pipeline_pag_sd import StableDiffusionPAGPipeline
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline
+123 -100
View File
@@ -12,15 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict, List, Tuple, Union
import torch
import torch.nn as nn
from ...models.attention_processor import (
Attention,
AttentionProcessor,
PAGCFGIdentitySelfAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
)
@@ -31,56 +25,123 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class PAGMixin:
r"""Mixin class for [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377v1)."""
r"""Mixin class for PAG."""
@staticmethod
def _check_input_pag_applied_layer(layer):
r"""
Check if each layer input in `applied_pag_layers` is valid. It should be either one of these 3 formats:
"{block_type}", "{block_type}.{block_index}", or "{block_type}.{block_index}.{attention_index}". `block_type`
can be "down", "mid", "up". `block_index` should be in the format of "block_{i}". `attention_index` should be
in the format of "attentions_{j}".
"""
layer_splits = layer.split(".")
if len(layer_splits) > 3:
raise ValueError(f"pag layer should only contains block_type, block_index and attention_index{layer}.")
if len(layer_splits) >= 1:
if layer_splits[0] not in ["down", "mid", "up"]:
raise ValueError(
f"Invalid block_type in pag layer {layer}. Accept 'down', 'mid', 'up', got {layer_splits[0]}"
)
if len(layer_splits) >= 2:
if not layer_splits[1].startswith("block_"):
raise ValueError(f"Invalid block_index in pag layer: {layer}. Should start with 'block_'")
if len(layer_splits) == 3:
if not layer_splits[2].startswith("attentions_"):
raise ValueError(f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_'")
def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
r"""
Set the attention processor for the PAG layers.
"""
pag_attn_processors = self._pag_attn_processors
if pag_attn_processors is None:
raise ValueError(
"No PAG attention processors have been set. Set the attention processors by calling `set_pag_applied_layers` and passing the relevant parameters."
)
pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1]
if hasattr(self, "unet"):
model: nn.Module = self.unet
if do_classifier_free_guidance:
pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0()
else:
model: nn.Module = self.transformer
pag_attn_proc = PAGIdentitySelfAttnProcessor2_0()
def is_self_attn(module: nn.Module) -> bool:
def is_self_attn(module_name):
r"""
Check if the module is self-attention module based on its name.
"""
return isinstance(module, Attention) and not module.is_cross_attention
return "attn1" in module_name and "to" not in name
def is_fake_integral_match(layer_id, name):
layer_id = layer_id.split(".")[-1]
name = name.split(".")[-1]
return layer_id.isnumeric() and name.isnumeric() and layer_id == name
def get_block_type(module_name):
r"""
Get the block type from the module name. can be "down", "mid", "up".
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "down"
return module_name.split(".")[0].split("_")[0]
for layer_id in pag_applied_layers:
def get_block_index(module_name):
r"""
Get the block index from the module name. can be "block_0", "block_1", ... If there is only one block (e.g.
mid_block) and index is ommited from the name, it will be "block_0".
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "block_1"
# mid_block.attentions.0.transformer_blocks.0.attn1 -> "block_0"
if "attentions" in module_name.split(".")[1]:
return "block_0"
else:
return f"block_{module_name.split('.')[1]}"
def get_attn_index(module_name):
r"""
Get the attention index from the module name. can be "attentions_0", "attentions_1", ...
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "attentions_0"
# mid_block.attentions.0.transformer_blocks.0.attn1 -> "attentions_0"
if "attentions" in module_name.split(".")[2]:
return f"attentions_{module_name.split('.')[3]}"
elif "attentions" in module_name.split(".")[1]:
return f"attentions_{module_name.split('.')[2]}"
for pag_layer_input in pag_applied_layers:
# for each PAG layer input, we find corresponding self-attention layers in the unet model
target_modules = []
for name, module in model.named_modules():
# Identify the following simple cases:
# (1) Self Attention layer existing
# (2) Whether the module name matches pag layer id even partially
# (3) Make sure it's not a fake integral match if the layer_id ends with a number
# For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1"
if (
is_self_attn(module)
and re.search(layer_id, name) is not None
and not is_fake_integral_match(layer_id, name)
):
logger.debug(f"Applying PAG to layer: {name}")
target_modules.append(module)
pag_layer_input_splits = pag_layer_input.split(".")
if len(pag_layer_input_splits) == 1:
# when the layer input only contains block_type. e.g. "mid", "down", "up"
block_type = pag_layer_input_splits[0]
for name, module in self.unet.named_modules():
if is_self_attn(name) and get_block_type(name) == block_type:
target_modules.append(module)
elif len(pag_layer_input_splits) == 2:
# when the layer inpput contains both block_type and block_index. e.g. "down.block_1", "mid.block_0"
block_type = pag_layer_input_splits[0]
block_index = pag_layer_input_splits[1]
for name, module in self.unet.named_modules():
if (
is_self_attn(name)
and get_block_type(name) == block_type
and get_block_index(name) == block_index
):
target_modules.append(module)
elif len(pag_layer_input_splits) == 3:
# when the layer input contains block_type, block_index and attention_index. e.g. "down.blocks_1.attentions_1"
block_type = pag_layer_input_splits[0]
block_index = pag_layer_input_splits[1]
attn_index = pag_layer_input_splits[2]
for name, module in self.unet.named_modules():
if (
is_self_attn(name)
and get_block_type(name) == block_type
and get_block_index(name) == block_index
and get_attn_index(name) == attn_index
):
target_modules.append(module)
if len(target_modules) == 0:
raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}")
raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}")
for module in target_modules:
module.processor = pag_attn_proc
@@ -143,95 +204,57 @@ class PAGMixin:
cond = torch.cat([uncond, cond], dim=0)
return cond
def set_pag_applied_layers(
self,
pag_applied_layers: Union[str, List[str]],
pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = (
PAGCFGIdentitySelfAttnProcessor2_0(),
PAGIdentitySelfAttnProcessor2_0(),
),
):
def set_pag_applied_layers(self, pag_applied_layers):
r"""
Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid.
Args:
pag_applied_layers (`str` or `List[str]`):
One or more strings identifying the layer names, or a simple regex for matching multiple layers, where
PAG is to be applied. A few ways of expected usage are as follows:
- Single layers specified as - "blocks.{layer_index}"
- Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...]
- Multiple layers as a block name - "mid"
- Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})"
pag_attn_processors:
(`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(),
PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention
processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second
attention processor is for PAG with CFG disabled (unconditional only).
set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid.
"""
if not hasattr(self, "_pag_attn_processors"):
self._pag_attn_processors = None
if not isinstance(pag_applied_layers, list):
pag_applied_layers = [pag_applied_layers]
if pag_attn_processors is not None:
if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2:
raise ValueError("Expected a tuple of two attention processors")
for i in range(len(pag_applied_layers)):
if not isinstance(pag_applied_layers[i], str):
raise ValueError(
f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}"
)
for pag_layer in pag_applied_layers:
self._check_input_pag_applied_layer(pag_layer)
self.pag_applied_layers = pag_applied_layers
self._pag_attn_processors = pag_attn_processors
@property
def pag_scale(self) -> float:
r"""Get the scale factor for the perturbed attention guidance."""
def pag_scale(self):
"""
Get the scale factor for the perturbed attention guidance.
"""
return self._pag_scale
@property
def pag_adaptive_scale(self) -> float:
r"""Get the adaptive scale factor for the perturbed attention guidance."""
def pag_adaptive_scale(self):
"""
Get the adaptive scale factor for the perturbed attention guidance.
"""
return self._pag_adaptive_scale
@property
def do_pag_adaptive_scaling(self) -> bool:
r"""Check if the adaptive scaling is enabled for the perturbed attention guidance."""
def do_pag_adaptive_scaling(self):
"""
Check if the adaptive scaling is enabled for the perturbed attention guidance.
"""
return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0
@property
def do_perturbed_attention_guidance(self) -> bool:
r"""Check if the perturbed attention guidance is enabled."""
def do_perturbed_attention_guidance(self):
"""
Check if the perturbed attention guidance is enabled.
"""
return self._pag_scale > 0 and len(self.pag_applied_layers) > 0
@property
def pag_attn_processors(self) -> Dict[str, AttentionProcessor]:
def pag_attn_processors(self):
r"""
Returns:
`dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model
with the key as the name of the layer.
"""
if self._pag_attn_processors is None:
return {}
valid_attn_processors = {x.__class__ for x in self._pag_attn_processors}
processors = {}
# We could have iterated through the self.components.items() and checked if a component is
# `ModelMixin` subclassed but that can include a VAE too.
if hasattr(self, "unet"):
denoiser_module = self.unet
elif hasattr(self, "transformer"):
denoiser_module = self.transformer
else:
raise ValueError("No denoiser module found.")
for name, proc in denoiser_module.attn_processors.items():
if proc.__class__ in valid_attn_processors:
for name, proc in self.unet.attn_processors.items():
if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
processors[name] = proc
return processors
@@ -1249,7 +1249,7 @@ class StableDiffusionControlNetPAGPipeline(
)
if guess_mode and self.do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
@@ -1,953 +0,0 @@
# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, HunyuanDiT2DModel
from ...models.attention_processor import PAGCFGHunyuanAttnProcessor2_0, PAGHunyuanAttnProcessor2_0
from ...models.embeddings import get_2d_rotary_pos_embed
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ...schedulers import DDPMScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pag_utils import PAGMixin
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import AutoPipelineForText2Image
>>> pipe = AutoPipelineForText2Image.from_pretrained(
... "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers",
... torch_dtype=torch.float16,
... enable_pag=True,
... pag_applied_layers=[14],
... ).to("cuda")
>>> # prompt = "an astronaut riding a horse"
>>> prompt = "一个宇航员在骑马"
>>> image = pipe(prompt, guidance_scale=4, pag_scale=3).images[0]
```
"""
STANDARD_RATIO = np.array(
[
1.0, # 1:1
4.0 / 3.0, # 4:3
3.0 / 4.0, # 3:4
16.0 / 9.0, # 16:9
9.0 / 16.0, # 9:16
]
)
STANDARD_SHAPE = [
[(1024, 1024), (1280, 1280)], # 1:1
[(1024, 768), (1152, 864), (1280, 960)], # 4:3
[(768, 1024), (864, 1152), (960, 1280)], # 3:4
[(1280, 768)], # 16:9
[(768, 1280)], # 9:16
]
STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE]
SUPPORTED_SHAPE = [
(1024, 1024),
(1280, 1280), # 1:1
(1024, 768),
(1152, 864),
(1280, 960), # 4:3
(768, 1024),
(864, 1152),
(960, 1280), # 3:4
(1280, 768), # 16:9
(768, 1280), # 9:16
]
def map_to_standard_shapes(target_width, target_height):
target_ratio = target_width / target_height
closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
return width, height
def get_resize_crop_region_for_grid(src, tgt_size):
th = tw = tgt_size
h, w = src
r = h / w
# resize
if r > 1:
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
r"""
Pipeline for English/Chinese-to-image generation using HunyuanDiT and [Perturbed Attention
Guidance](https://huggingface.co/docs/diffusers/en/using-diffusers/pag).
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by
ourselves)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use
`sdxl-vae-fp16-fix`.
text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
HunyuanDiT uses a fine-tuned [bilingual CLIP].
tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
transformer ([`HunyuanDiT2DModel`]):
The HunyuanDiT model designed by Tencent Hunyuan.
text_encoder_2 (`T5EncoderModel`):
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
tokenizer_2 (`MT5Tokenizer`):
The tokenizer for the mT5 embedder.
scheduler ([`DDPMScheduler`]):
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = [
"safety_checker",
"feature_extractor",
"text_encoder_2",
"tokenizer_2",
"text_encoder",
"tokenizer",
]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"prompt_embeds_2",
"negative_prompt_embeds_2",
]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: BertModel,
tokenizer: BertTokenizer,
transformer: HunyuanDiT2DModel,
scheduler: DDPMScheduler,
safety_checker: Optional[StableDiffusionSafetyChecker] = None,
feature_extractor: Optional[CLIPImageProcessor] = None,
requires_safety_checker: bool = True,
text_encoder_2: Optional[T5EncoderModel] = None,
tokenizer_2: Optional[MT5Tokenizer] = None,
pag_applied_layers: Union[str, List[str]] = "blocks.1", # "blocks.16.attn1", "blocks.16", "16", 16
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
text_encoder_2=text_encoder_2,
)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
self.default_sample_size = (
self.transformer.config.sample_size
if hasattr(self, "transformer") and self.transformer is not None
else 128
)
self.set_pag_applied_layers(
pag_applied_layers, pag_attn_processors=(PAGCFGHunyuanAttnProcessor2_0(), PAGHunyuanAttnProcessor2_0())
)
# Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.encode_prompt
def encode_prompt(
self,
prompt: str,
device: torch.device = None,
dtype: torch.dtype = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: Optional[int] = None,
text_encoder_index: int = 0,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
dtype (`torch.dtype`):
torch dtype
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
text_encoder_index (`int`, *optional*):
Index of the text encoder to use. `0` for clip and `1` for T5.
"""
if dtype is None:
if self.text_encoder_2 is not None:
dtype = self.text_encoder_2.dtype
elif self.transformer is not None:
dtype = self.transformer.dtype
else:
dtype = None
if device is None:
device = self._execution_device
tokenizers = [self.tokenizer, self.tokenizer_2]
text_encoders = [self.text_encoder, self.text_encoder_2]
tokenizer = tokenizers[text_encoder_index]
text_encoder = text_encoders[text_encoder_index]
if max_sequence_length is None:
if text_encoder_index == 0:
max_length = 77
if text_encoder_index == 1:
max_length = 256
else:
max_length = max_sequence_length
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds = text_encoder(
text_input_ids.to(device),
attention_mask=prompt_attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
negative_prompt_embeds = text_encoder(
uncond_input.input_ids.to(device),
attention_mask=negative_prompt_attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
prompt_embeds_2=None,
negative_prompt_embeds_2=None,
prompt_attention_mask_2=None,
negative_prompt_attention_mask_2=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is None and prompt_embeds_2 is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
raise ValueError(
"Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
raise ValueError(
"`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
f" {negative_prompt_embeds_2.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_2: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
prompt_attention_mask_2: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = (1024, 1024),
target_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
use_resolution_binning: bool = True,
pag_scale: float = 3.0,
pag_adaptive_scale: float = 0.0,
):
r"""
The call function to the pipeline for generation with HunyuanDiT.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
prompt_attention_mask_2 (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A callback function or a list of callback functions to be called at the end of each denoising step.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
A list of tensor inputs that should be passed to the callback function. If not defined, all tensor
inputs will be passed.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise
Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
The original size of the image. Used to calculate the time ids.
target_size (`Tuple[int, int]`, *optional*):
The target size of the image. Used to calculate the time ids.
crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
The top left coordinates of the crop. Used to calculate the time ids.
use_resolution_binning (`bool`, *optional*, defaults to `True`):
Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest
standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960,
768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`.
pag_scale (`float`, *optional*, defaults to 3.0):
The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
guidance will not be used.
pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
used.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images and the
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. Default height and width
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
height = int((height // 16) * 16)
width = int((width // 16) * 16)
if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE:
width, height = map_to_standard_shapes(width, height)
height = int(height)
width = int(width)
logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}")
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._interrupt = False
self._pag_scale = pag_scale
self._pag_adaptive_scale = pag_adaptive_scale
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=self.transformer.dtype,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=77,
text_encoder_index=0,
)
(
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=self.transformer.dtype,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds_2,
negative_prompt_embeds=negative_prompt_embeds_2,
prompt_attention_mask=prompt_attention_mask_2,
negative_prompt_attention_mask=negative_prompt_attention_mask_2,
max_sequence_length=256,
text_encoder_index=1,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Create image_rotary_emb, style embedding & time ids
grid_height = height // 8 // self.transformer.config.patch_size
grid_width = width // 8 // self.transformer.config.patch_size
base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed(
self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
)
style = torch.tensor([0], device=device)
target_size = target_size or (height, width)
add_time_ids = list(original_size + target_size + crops_coords_top_left)
add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_perturbed_attention_guidance:
prompt_embeds = self._prepare_perturbed_attention_guidance(
prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
)
prompt_attention_mask = self._prepare_perturbed_attention_guidance(
prompt_attention_mask, negative_prompt_attention_mask, self.do_classifier_free_guidance
)
prompt_embeds_2 = self._prepare_perturbed_attention_guidance(
prompt_embeds_2, negative_prompt_embeds_2, self.do_classifier_free_guidance
)
prompt_attention_mask_2 = self._prepare_perturbed_attention_guidance(
prompt_attention_mask_2, negative_prompt_attention_mask_2, self.do_classifier_free_guidance
)
add_time_ids = torch.cat([add_time_ids] * 3, dim=0)
style = torch.cat([style] * 3, dim=0)
elif self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
style = torch.cat([style] * 2, dim=0)
prompt_embeds = prompt_embeds.to(device=device)
prompt_attention_mask = prompt_attention_mask.to(device=device)
prompt_embeds_2 = prompt_embeds_2.to(device=device)
prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
batch_size * num_images_per_prompt, 1
)
style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
if self.do_perturbed_attention_guidance:
original_attn_proc = self.transformer.attn_processors
self._set_pag_attn_processor(
pag_applied_layers=self.pag_applied_layers,
do_classifier_free_guidance=self.do_classifier_free_guidance,
)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
dtype=latent_model_input.dtype
)
# predict the noise residual
noise_pred = self.transformer(
latent_model_input,
t_expand,
encoder_hidden_states=prompt_embeds,
text_embedding_mask=prompt_attention_mask,
encoder_hidden_states_t5=prompt_embeds_2,
text_embedding_mask_t5=prompt_attention_mask_2,
image_meta_size=add_time_ids,
style=style,
image_rotary_emb=image_rotary_emb,
return_dict=False,
)[0]
noise_pred, _ = noise_pred.chunk(2, dim=1)
# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
)
elif self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
negative_prompt_embeds_2 = callback_outputs.pop(
"negative_prompt_embeds_2", negative_prompt_embeds_2
)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# 9. Offload all models
self.maybe_free_model_hooks()
if self.do_perturbed_attention_guidance:
self.transformer.set_attn_processor(original_attn_proc)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
@@ -1,872 +0,0 @@
# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
import inspect
import re
import urllib.parse as ul
from typing import Callable, List, Optional, Tuple, Union
import torch
from transformers import T5EncoderModel, T5Tokenizer
from ...image_processor import PixArtImageProcessor
from ...models import AutoencoderKL, PixArtTransformer2DModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
BACKENDS_MAPPING,
deprecate,
is_bs4_available,
is_ftfy_available,
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ..pixart_alpha.pipeline_pixart_alpha import (
ASPECT_RATIO_256_BIN,
ASPECT_RATIO_512_BIN,
ASPECT_RATIO_1024_BIN,
)
from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
from .pag_utils import PAGMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_bs4_available():
from bs4 import BeautifulSoup
if is_ftfy_available():
import ftfy
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import AutoPipelineForText2Image
>>> pipe = AutoPipelineForText2Image.from_pretrained(
... "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
... torch_dtype=torch.float16,
... pag_applied_layers=["blocks.14"],
... enable_pag=True,
... )
>>> pipe = pipe.to("cuda")
>>> prompt = "A small cactus with a happy face in the Sahara desert"
>>> image = pipe(prompt, pag_scale=4.0, guidance_scale=1.0).images[0]
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
r"""
[PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for text-to-image generation
using PixArt-Sigma.
"""
bad_punct_regex = re.compile(
r"["
+ "#®•©™&@·º½¾¿¡§~"
+ r"\)"
+ r"\("
+ r"\]"
+ r"\["
+ r"\}"
+ r"\{"
+ r"\|"
+ "\\"
+ r"\/"
+ r"\*"
+ r"]{1,}"
) # noqa
_optional_components = ["tokenizer", "text_encoder"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKL,
transformer: PixArtTransformer2DModel,
scheduler: KarrasDiffusionSchedulers,
pag_applied_layers: Union[str, List[str]] = "blocks.1", # 1st transformer block
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.set_pag_applied_layers(pag_applied_layers)
# Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300
def encode_prompt(
self,
prompt: Union[str, List[str]],
do_classifier_free_guidance: bool = True,
negative_prompt: str = "",
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
clean_caption: bool = False,
max_sequence_length: int = 300,
**kwargs,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
PixArt-Alpha, this should be "".
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
whether to use classifier free guidance or not
num_images_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
string.
clean_caption (`bool`, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
"""
if "mask_feature" in kwargs:
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
if device is None:
device = self._execution_device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# See Section 3.1. of the paper.
max_length = max_sequence_length
if prompt_embeds is None:
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because T5 can only handle sequences up to"
f" {max_length} tokens: {removed_text}"
)
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0]
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
elif self.transformer is not None:
dtype = self.transformer.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
negative_prompt_attention_mask = uncond_input.attention_mask
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
else:
negative_prompt_embeds = None
negative_prompt_attention_mask = None
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape}."
)
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
text = [text]
def process(text: str):
if clean_caption:
text = self._clean_caption(text)
text = self._clean_caption(text)
else:
text = text.lower().strip()
return text
return [process(t) for t in text]
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
caption = re.sub("<person>", "person", caption)
# urls:
caption = re.sub(
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
caption = re.sub(
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
# html:
caption = BeautifulSoup(caption, features="html.parser").text
# @<nickname>
caption = re.sub(r"@[\w\d]+\b", "", caption)
# 31C0—31EF CJK Strokes
# 31F0—31FF Katakana Phonetic Extensions
# 3200—32FF Enclosed CJK Letters and Months
# 3300—33FF CJK Compatibility
# 3400—4DBF CJK Unified Ideographs Extension A
# 4DC0—4DFF Yijing Hexagram Symbols
# 4E00—9FFF CJK Unified Ideographs
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
#######################################################
# все виды тире / all types of dash --> "-"
caption = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
"-",
caption,
)
# кавычки к одному стандарту
caption = re.sub(r"[`´«»“”¨]", '"', caption)
caption = re.sub(r"[‘’]", "'", caption)
# &quot;
caption = re.sub(r"&quot;?", "", caption)
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
caption = re.sub(r"\d:\d\d\s+$", "", caption)
# \n
caption = re.sub(r"\\n", " ", caption)
# "#123"
caption = re.sub(r"#\d{1,3}\b", "", caption)
# "#12345.."
caption = re.sub(r"#\d{5,}\b", "", caption)
# "123456.."
caption = re.sub(r"\b\d{6,}\b", "", caption)
# filenames:
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
#
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, caption)) > 3:
caption = re.sub(regex2, " ", caption)
caption = ftfy.fix_text(caption)
caption = html.unescape(html.unescape(caption))
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
caption = re.sub(r"\s+", " ", caption)
caption.strip()
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
caption = re.sub(r"^\.\S+$", "", caption)
return caption.strip()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: str = "",
num_inference_steps: int = 20,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: float = 4.5,
num_images_per_prompt: Optional[int] = 1,
height: Optional[int] = None,
width: Optional[int] = None,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
clean_caption: bool = True,
use_resolution_binning: bool = True,
max_sequence_length: int = 300,
pag_scale: float = 3.0,
pag_adaptive_scale: float = 0.0,
) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 4.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
use_resolution_binning (`bool` defaults to `True`):
If set to `True`, the requested height and width are first mapped to the closest resolutions using
`ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
the requested resolution. Useful for generating non-square images.
max_sequence_length (`int` defaults to 300): Maximum sequence length to use with the `prompt`.
pag_scale (`float`, *optional*, defaults to 3.0):
The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
guidance will not be used.
pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
used.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated images
"""
# 1. Check inputs. Raise error if not correct
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
if use_resolution_binning:
if self.transformer.config.sample_size == 256:
aspect_ratio_bin = ASPECT_RATIO_2048_BIN
elif self.transformer.config.sample_size == 128:
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
elif self.transformer.config.sample_size == 64:
aspect_ratio_bin = ASPECT_RATIO_512_BIN
elif self.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_256_BIN
else:
raise ValueError("Invalid sample size")
orig_height, orig_width = height, width
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
)
self._pag_scale = pag_scale
self._pag_adaptive_scale = pag_adaptive_scale
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
)
if self.do_perturbed_attention_guidance:
prompt_embeds = self._prepare_perturbed_attention_guidance(
prompt_embeds, negative_prompt_embeds, do_classifier_free_guidance
)
prompt_attention_mask = self._prepare_perturbed_attention_guidance(
prompt_attention_mask, negative_prompt_attention_mask, do_classifier_free_guidance
)
elif do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
if self.do_perturbed_attention_guidance:
original_attn_proc = self.transformer.attn_processors
self._set_pag_attn_processor(
pag_applied_layers=self.pag_applied_layers,
do_classifier_free_guidance=do_classifier_free_guidance,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.1 Prepare micro-conditions.
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
current_timestep = t
if not torch.is_tensor(current_timestep):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
latent_model_input,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
timestep=current_timestep,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, do_classifier_free_guidance, guidance_scale, current_timestep
)
elif do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# learned sigma
if self.transformer.config.out_channels // 2 == latent_channels:
noise_pred = noise_pred.chunk(2, dim=1)[0]
else:
noise_pred = noise_pred
# compute previous image: x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
else:
image = latents
if not output_type == "latent":
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if self.do_perturbed_attention_guidance:
self.transformer.set_attn_processor(original_attn_proc)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
@@ -1,846 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_motion_model import MotionAdapter
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..animatediff.pipeline_output import AnimateDiffPipelineOutput
from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pag_utils import PAGMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import AnimateDiffPAGPipeline, MotionAdapter, DDIMScheduler
>>> from diffusers.utils import export_to_gif
>>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
>>> motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-2"
>>> motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id)
>>> scheduler = DDIMScheduler.from_pretrained(
... model_id, subfolder="scheduler", beta_schedule="linear", steps_offset=1, clip_sample=False
... )
>>> pipe = AnimateDiffPAGPipeline.from_pretrained(
... model_id,
... motion_adapter=motion_adapter,
... scheduler=scheduler,
... pag_applied_layers=["mid"],
... torch_dtype=torch.float16,
... ).to("cuda")
>>> video = pipe(
... prompt="car, futuristic cityscape with neon lights, street, no human",
... negative_prompt="low quality, bad quality",
... num_inference_steps=25,
... guidance_scale=6.0,
... pag_scale=3.0,
... generator=torch.Generator().manual_seed(42),
... ).frames[0]
>>> export_to_gif(video, "animatediff_pag.gif")
```
"""
class AnimateDiffPAGPipeline(
DiffusionPipeline,
StableDiffusionMixin,
TextualInversionLoaderMixin,
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
PAGMixin,
):
r"""
Pipeline for text-to-video generation using
[AnimateDiff](https://huggingface.co/docs/diffusers/en/api/pipelines/animatediff) and [Perturbed Attention
Guidance](https://huggingface.co/docs/diffusers/en/using-diffusers/pag).
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
tokenizer (`CLIPTokenizer`):
A [`~transformers.CLIPTokenizer`] to tokenize text.
unet ([`UNet2DConditionModel`]):
A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
motion_adapter ([`MotionAdapter`]):
A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: Union[UNet2DConditionModel, UNetMotionModel],
motion_adapter: MotionAdapter,
scheduler: KarrasDiffusionSchedulers,
feature_extractor: CLIPImageProcessor = None,
image_encoder: CLIPVisionModelWithProjection = None,
pag_applied_layers: Union[str, List[str]] = "mid_block.*attn1", # ["mid"], ["down_blocks.1"]
):
super().__init__()
if isinstance(unet, UNet2DConditionModel):
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
motion_adapter=motion_adapter,
scheduler=scheduler,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
self.set_pag_applied_layers(pag_applied_layers)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if clip_skip is None:
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
# We also need to apply the final LayerNorm here to not mess with the
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
if self.text_encoder is not None:
prompt_embeds_dtype = self.text_encoder.dtype
elif self.unet is not None:
prompt_embeds_dtype = self.unet.dtype
else:
prompt_embeds_dtype = prompt_embeds.dtype
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if self.text_encoder is not None:
if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
image = self.vae.decode(latents).sample
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.pia.pipeline_pia.PIAPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
ip_adapter_image=None,
ip_adapter_image_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
)
if ip_adapter_image_embeds is not None:
if not isinstance(ip_adapter_image_embeds, list):
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
shape = (
batch_size,
num_channels_latents,
num_frames,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
num_frames: Optional[int] = 16,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_videos_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
pag_scale: float = 3.0,
pag_adaptive_scale: float = 0.0,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated video.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated video.
num_frames (`int`, *optional*, defaults to 16):
The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
amounts to 2 seconds of video.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
`(batch_size, num_channel, num_frames, height, width)`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
of a plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
pag_scale (`float`, *optional*, defaults to 3.0):
The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
guidance will not be used.
pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
used.
Examples:
Returns:
[`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._pag_scale = pag_scale
self._pag_adaptive_scale = pag_adaptive_scale
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Encode input prompt
text_encoder_lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_perturbed_attention_guidance:
prompt_embeds = self._prepare_perturbed_attention_guidance(
prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
)
elif self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_videos_per_prompt,
self.do_classifier_free_guidance,
)
for i, image_embeds in enumerate(ip_adapter_image_embeds):
negative_image_embeds = None
if self.do_classifier_free_guidance:
negative_image_embeds, image_embeds = image_embeds.chunk(2)
if self.do_perturbed_attention_guidance:
image_embeds = self._prepare_perturbed_attention_guidance(
image_embeds, negative_image_embeds, self.do_classifier_free_guidance
)
elif self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
image_embeds = image_embeds.to(device)
ip_adapter_image_embeds[i] = image_embeds
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
num_frames,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Add image embeds for IP-Adapter
added_cond_kwargs = (
{"image_embeds": ip_adapter_image_embeds}
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
else None
)
if self.do_perturbed_attention_guidance:
original_attn_proc = self.unet.attn_processors
self._set_pag_attn_processor(
pag_applied_layers=self.pag_applied_layers,
do_classifier_free_guidance=self.do_classifier_free_guidance,
)
num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
for free_init_iter in range(num_free_init_iters):
if self.free_init_enabled:
latents, timesteps = self._apply_free_init(
latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
)
self._num_timesteps = len(timesteps)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
# 8. Denoising loop
with self.progress_bar(total=self._num_timesteps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
).sample
# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
)
elif self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
# 9. Post processing
if output_type == "latent":
video = latents
else:
video_tensor = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 10. Offload all models
self.maybe_free_model_hooks()
if self.do_perturbed_attention_guidance:
self.unet.set_attn_processor(original_attn_proc)
if not return_dict:
return (video,)
return AnimateDiffPipelineOutput(frames=video)
+8 -7
View File
@@ -54,21 +54,22 @@ EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import EulerDiscreteScheduler, MotionAdapter, PIAPipeline
>>> from diffusers import (
... EulerDiscreteScheduler,
... MotionAdapter,
... PIAPipeline,
... )
>>> from diffusers.utils import export_to_gif, load_image
>>> adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter")
>>> pipe = PIAPipeline.from_pretrained(
... "SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
... )
>>> adapter = MotionAdapter.from_pretrained("../checkpoints/pia-diffusers")
>>> pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter)
>>> pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
>>> image = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
... )
>>> image = image.resize((512, 512))
>>> prompt = "cat in a hat"
>>> negative_prompt = "wrong white balance, dark, sketches, worst quality, low quality, deformed, distorted"
>>> negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
>>> generator = torch.Generator("cpu").manual_seed(0)
>>> output = pipe(image=image, prompt=prompt, negative_prompt=negative_prompt, generator=generator)
>>> frames = output.frames[0]
@@ -1,50 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
is_transformers_version,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modeling_stable_audio"] = ["StableAudioProjectionModel"]
_import_structure["pipeline_stable_audio"] = ["StableAudioPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .modeling_stable_audio import StableAudioProjectionModel
from .pipeline_stable_audio import StableAudioPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -1,158 +0,0 @@
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from math import pi
from typing import Optional
import torch
import torch.nn as nn
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
from ...utils import BaseOutput, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableAudioPositionalEmbedding(nn.Module):
"""Used for continuous time"""
def __init__(self, dim: int):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim))
def forward(self, times: torch.Tensor) -> torch.Tensor:
times = times[..., None]
freqs = times * self.weights[None] * 2 * pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
fouriered = torch.cat((times, fouriered), dim=-1)
return fouriered
@dataclass
class StableAudioProjectionModelOutput(BaseOutput):
"""
Args:
Class for StableAudio projection layer's outputs.
text_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states obtained by linearly projecting the hidden-states for the text encoder.
seconds_start_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
Sequence of hidden-states obtained by linearly projecting the audio start hidden states.
seconds_end_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
Sequence of hidden-states obtained by linearly projecting the audio end hidden states.
"""
text_hidden_states: Optional[torch.Tensor] = None
seconds_start_hidden_states: Optional[torch.Tensor] = None
seconds_end_hidden_states: Optional[torch.Tensor] = None
class StableAudioNumberConditioner(nn.Module):
"""
A simple linear projection model to map numbers to a latent space.
Args:
number_embedding_dim (`int`):
Dimensionality of the number embeddings.
min_value (`int`):
The minimum value of the seconds number conditioning modules.
max_value (`int`):
The maximum value of the seconds number conditioning modules
internal_dim (`int`):
Dimensionality of the intermediate number hidden states.
"""
def __init__(
self,
number_embedding_dim,
min_value,
max_value,
internal_dim: Optional[int] = 256,
):
super().__init__()
self.time_positional_embedding = nn.Sequential(
StableAudioPositionalEmbedding(internal_dim),
nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim),
)
self.number_embedding_dim = number_embedding_dim
self.min_value = min_value
self.max_value = max_value
def forward(
self,
floats: torch.Tensor,
):
floats = floats.clamp(self.min_value, self.max_value)
normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value)
# Cast floats to same type as embedder
embedder_dtype = next(self.time_positional_embedding.parameters()).dtype
normalized_floats = normalized_floats.to(embedder_dtype)
embedding = self.time_positional_embedding(normalized_floats)
float_embeds = embedding.view(-1, 1, self.number_embedding_dim)
return float_embeds
class StableAudioProjectionModel(ModelMixin, ConfigMixin):
"""
A simple linear projection model to map the conditioning values to a shared latent space.
Args:
text_encoder_dim (`int`):
Dimensionality of the text embeddings from the text encoder (T5).
conditioning_dim (`int`):
Dimensionality of the output conditioning tensors.
min_value (`int`):
The minimum value of the seconds number conditioning modules.
max_value (`int`):
The maximum value of the seconds number conditioning modules
"""
@register_to_config
def __init__(self, text_encoder_dim, conditioning_dim, min_value, max_value):
super().__init__()
self.text_projection = (
nn.Identity() if conditioning_dim == text_encoder_dim else nn.Linear(text_encoder_dim, conditioning_dim)
)
self.start_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)
self.end_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)
def forward(
self,
text_hidden_states: Optional[torch.Tensor] = None,
start_seconds: Optional[torch.Tensor] = None,
end_seconds: Optional[torch.Tensor] = None,
):
text_hidden_states = (
text_hidden_states if text_hidden_states is None else self.text_projection(text_hidden_states)
)
seconds_start_hidden_states = (
start_seconds if start_seconds is None else self.start_number_conditioner(start_seconds)
)
seconds_end_hidden_states = end_seconds if end_seconds is None else self.end_number_conditioner(end_seconds)
return StableAudioProjectionModelOutput(
text_hidden_states=text_hidden_states,
seconds_start_hidden_states=seconds_start_hidden_states,
seconds_end_hidden_states=seconds_end_hidden_states,
)
@@ -1,745 +0,0 @@
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, List, Optional, Union
import torch
from transformers import (
T5EncoderModel,
T5Tokenizer,
T5TokenizerFast,
)
from ...models import AutoencoderOobleck, StableAudioDiTModel
from ...models.embeddings import get_1d_rotary_pos_embed
from ...schedulers import EDMDPMSolverMultistepScheduler
from ...utils import (
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
from .modeling_stable_audio import StableAudioProjectionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import scipy
>>> import torch
>>> import soundfile as sf
>>> from diffusers import StableAudioPipeline
>>> repo_id = "stabilityai/stable-audio-open-1.0"
>>> pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")
>>> # define the prompts
>>> prompt = "The sound of a hammer hitting a wooden surface."
>>> negative_prompt = "Low quality."
>>> # set the seed for generator
>>> generator = torch.Generator("cuda").manual_seed(0)
>>> # run the generation
>>> audio = pipe(
... prompt,
... negative_prompt=negative_prompt,
... num_inference_steps=200,
... audio_end_in_s=10.0,
... num_waveforms_per_prompt=3,
... generator=generator,
... ).audios
>>> output = audio[0].T.float().cpu().numpy()
>>> sf.write("hammer.wav", output, pipe.vae.sampling_rate)
```
"""
class StableAudioPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-audio generation using StableAudio.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
vae ([`AutoencoderOobleck`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder ([`~transformers.T5EncoderModel`]):
Frozen text-encoder. StableAudio uses the encoder of
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[google-t5/t5-base](https://huggingface.co/google-t5/t5-base) variant.
projection_model ([`StableAudioProjectionModel`]):
A trained model used to linearly project the hidden-states from the text encoder model and the start and
end seconds. The projected hidden-states from the encoder and the conditional seconds are concatenated to
give the input to the transformer model.
tokenizer ([`~transformers.T5Tokenizer`]):
Tokenizer to tokenize text for the frozen text-encoder.
transformer ([`StableAudioDiTModel`]):
A `StableAudioDiTModel` to denoise the encoded audio latents.
scheduler ([`EDMDPMSolverMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded audio latents.
"""
model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae"
def __init__(
self,
vae: AutoencoderOobleck,
text_encoder: T5EncoderModel,
projection_model: StableAudioProjectionModel,
tokenizer: Union[T5Tokenizer, T5TokenizerFast],
transformer: StableAudioDiTModel,
scheduler: EDMDPMSolverMultistepScheduler,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
projection_model=projection_model,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def encode_prompt(
self,
prompt,
device,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
negative_attention_mask: Optional[torch.LongTensor] = None,
):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# 1. Tokenize text
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
f"The following part of your input was truncated because {self.text_encoder.config.model_type} can "
f"only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids.to(device)
attention_mask = attention_mask.to(device)
# 2. Text encoder forward
self.text_encoder.eval()
prompt_embeds = self.text_encoder(
text_input_ids,
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
if do_classifier_free_guidance and negative_prompt is not None:
uncond_tokens: List[str]
if type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
# 1. Tokenize text
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
uncond_input_ids = uncond_input.input_ids.to(device)
negative_attention_mask = uncond_input.attention_mask.to(device)
# 2. Text encoder forward
self.text_encoder.eval()
negative_prompt_embeds = self.text_encoder(
uncond_input_ids,
attention_mask=negative_attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if negative_attention_mask is not None:
# set the masked tokens to the null embed
negative_prompt_embeds = torch.where(
negative_attention_mask.to(torch.bool).unsqueeze(2), negative_prompt_embeds, 0.0
)
# 3. Project prompt_embeds and negative_prompt_embeds
if do_classifier_free_guidance and negative_prompt_embeds is not None:
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the negative and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if attention_mask is not None and negative_attention_mask is None:
negative_attention_mask = torch.ones_like(attention_mask)
elif attention_mask is None and negative_attention_mask is not None:
attention_mask = torch.ones_like(negative_attention_mask)
if attention_mask is not None:
attention_mask = torch.cat([negative_attention_mask, attention_mask])
prompt_embeds = self.projection_model(
text_hidden_states=prompt_embeds,
).text_hidden_states
if attention_mask is not None:
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
return prompt_embeds
def encode_duration(
self,
audio_start_in_s,
audio_end_in_s,
device,
do_classifier_free_guidance,
batch_size,
):
audio_start_in_s = audio_start_in_s if isinstance(audio_start_in_s, list) else [audio_start_in_s]
audio_end_in_s = audio_end_in_s if isinstance(audio_end_in_s, list) else [audio_end_in_s]
if len(audio_start_in_s) == 1:
audio_start_in_s = audio_start_in_s * batch_size
if len(audio_end_in_s) == 1:
audio_end_in_s = audio_end_in_s * batch_size
# Cast the inputs to floats
audio_start_in_s = [float(x) for x in audio_start_in_s]
audio_start_in_s = torch.tensor(audio_start_in_s).to(device)
audio_end_in_s = [float(x) for x in audio_end_in_s]
audio_end_in_s = torch.tensor(audio_end_in_s).to(device)
projection_output = self.projection_model(
start_seconds=audio_start_in_s,
end_seconds=audio_end_in_s,
)
seconds_start_hidden_states = projection_output.seconds_start_hidden_states
seconds_end_hidden_states = projection_output.seconds_end_hidden_states
# For classifier free guidance, we need to do two forward passes.
# Here we repeat the audio hidden states to avoid doing two forward passes
if do_classifier_free_guidance:
seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states], dim=0)
seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states], dim=0)
return seconds_start_hidden_states, seconds_end_hidden_states
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
audio_start_in_s,
audio_end_in_s,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
attention_mask=None,
negative_attention_mask=None,
initial_audio_waveforms=None,
initial_audio_sampling_rate=None,
):
if audio_end_in_s < audio_start_in_s:
raise ValueError(
f"`audio_end_in_s={audio_end_in_s}' must be higher than 'audio_start_in_s={audio_start_in_s}` but "
)
if (
audio_start_in_s < self.projection_model.config.min_value
or audio_start_in_s > self.projection_model.config.max_value
):
raise ValueError(
f"`audio_start_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but "
f"is {audio_start_in_s}."
)
if (
audio_end_in_s < self.projection_model.config.min_value
or audio_end_in_s > self.projection_model.config.max_value
):
raise ValueError(
f"`audio_end_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but "
f"is {audio_end_in_s}."
)
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and (prompt_embeds is None):
raise ValueError(
"Provide either `prompt`, or `prompt_embeds`. Cannot leave"
"`prompt` undefined without specifying `prompt_embeds`."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
raise ValueError(
"`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
)
if initial_audio_sampling_rate is None and initial_audio_waveforms is not None:
raise ValueError(
"`initial_audio_waveforms' is provided but the sampling rate is not. Make sure to pass `initial_audio_sampling_rate`."
)
if initial_audio_sampling_rate is not None and initial_audio_sampling_rate != self.vae.sampling_rate:
raise ValueError(
f"`initial_audio_sampling_rate` must be {self.vae.hop_length}' but is `{initial_audio_sampling_rate}`."
"Make sure to resample the `initial_audio_waveforms` and to correct the sampling rate. "
)
def prepare_latents(
self,
batch_size,
num_channels_vae,
sample_size,
dtype,
device,
generator,
latents=None,
initial_audio_waveforms=None,
num_waveforms_per_prompt=None,
audio_channels=None,
):
shape = (batch_size, num_channels_vae, sample_size)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
# encode the initial audio for use by the model
if initial_audio_waveforms is not None:
# check dimension
if initial_audio_waveforms.ndim == 2:
initial_audio_waveforms = initial_audio_waveforms.unsqueeze(1)
elif initial_audio_waveforms.ndim != 3:
raise ValueError(
f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
)
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
# check num_channels
if initial_audio_waveforms.shape[1] == 1 and audio_channels == 2:
initial_audio_waveforms = initial_audio_waveforms.repeat(1, 2, 1)
elif initial_audio_waveforms.shape[1] == 2 and audio_channels == 1:
initial_audio_waveforms = initial_audio_waveforms.mean(1, keepdim=True)
if initial_audio_waveforms.shape[:2] != audio_shape[:2]:
raise ValueError(
f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but is of shape `{initial_audio_waveforms.shape}`"
)
# crop or pad
audio_length = initial_audio_waveforms.shape[-1]
if audio_length < audio_vae_length:
logger.warning(
f"The provided input waveform is shorter ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be padded."
)
elif audio_length > audio_vae_length:
logger.warning(
f"The provided input waveform is longer ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be cropped."
)
audio = initial_audio_waveforms.new_zeros(audio_shape)
audio[:, :, : min(audio_length, audio_vae_length)] = initial_audio_waveforms[:, :, :audio_vae_length]
encoded_audio = self.vae.encode(audio).latent_dist.sample(generator)
encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1))
latents = encoded_audio + latents
return latents
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
audio_end_in_s: Optional[float] = None,
audio_start_in_s: Optional[float] = 0.0,
num_inference_steps: int = 100,
guidance_scale: float = 7.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_waveforms_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
initial_audio_waveforms: Optional[torch.Tensor] = None,
initial_audio_sampling_rate: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
negative_attention_mask: Optional[torch.LongTensor] = None,
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: Optional[int] = 1,
output_type: Optional[str] = "pt",
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
audio_end_in_s (`float`, *optional*, defaults to 47.55):
Audio end index in seconds.
audio_start_in_s (`float`, *optional*, defaults to 0):
Audio start index in seconds.
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.0):
A higher guidance scale value encourages the model to generate audio that is closely linked to the text
`prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
The number of waveforms to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
initial_audio_waveforms (`torch.Tensor`, *optional*):
Optional initial audio waveforms to use as the initial audio waveform for generation. Must be of shape
`(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size`
corresponds to the number of prompts passed to the model.
initial_audio_sampling_rate (`int`, *optional*):
Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs,
*e.g.* prompt weighting. If not provided, text embeddings will be computed from `prompt` input
argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text
inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
`negative_prompt` input argument.
attention_mask (`torch.LongTensor`, *optional*):
Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
be computed from `prompt` input argument.
negative_attention_mask (`torch.LongTensor`, *optional*):
Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
output_type (`str`, *optional*, defaults to `"pt"`):
The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
`"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
model (LDM) output.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated audio.
"""
# 0. Convert audio input length from seconds to latent length
downsample_ratio = self.vae.hop_length
max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate
if audio_end_in_s is None:
audio_end_in_s = max_audio_length_in_s
if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
raise ValueError(
f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
)
waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate)
waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate)
waveform_length = int(self.transformer.config.sample_size)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
audio_start_in_s,
audio_end_in_s,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
attention_mask,
negative_attention_mask,
initial_audio_waveforms,
initial_audio_sampling_rate,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds = self.encode_prompt(
prompt,
device,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
attention_mask,
negative_attention_mask,
)
# Encode duration
seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration(
audio_start_in_s,
audio_end_in_s,
device,
do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None),
batch_size,
)
# Create text_audio_duration_embeds and audio_duration_embeds
text_audio_duration_embeds = torch.cat(
[prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1
)
audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2)
# In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and
# to concatenate it to the embeddings
if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None:
negative_text_audio_duration_embeds = torch.zeros_like(
text_audio_duration_embeds, device=text_audio_duration_embeds.device
)
text_audio_duration_embeds = torch.cat(
[negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0
)
audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0)
bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape
# duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method
text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
text_audio_duration_embeds = text_audio_duration_embeds.view(
bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
)
audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
audio_duration_embeds = audio_duration_embeds.view(
bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1]
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_vae = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_waveforms_per_prompt,
num_channels_vae,
waveform_length,
text_audio_duration_embeds.dtype,
device,
generator,
latents,
initial_audio_waveforms,
num_waveforms_per_prompt,
audio_channels=self.vae.config.audio_channels,
)
# 6. Prepare extra step kwargs
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Prepare rotary positional embedding
rotary_embedding = get_1d_rotary_pos_embed(
self.rotary_embed_dim,
latents.shape[2] + audio_duration_embeds.shape[1],
use_real=True,
repeat_interleave_real=False,
)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.transformer(
latent_model_input,
t.unsqueeze(0),
encoder_hidden_states=text_audio_duration_embeds,
global_hidden_states=audio_duration_embeds,
rotary_embedding=rotary_embedding,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# 9. Post-processing
if not output_type == "latent":
audio = self.vae.decode(latents).sample
else:
return AudioPipelineOutput(audios=latents)
audio = audio[:, :, waveform_start:waveform_end]
if output_type == "np":
audio = audio.cpu().float().numpy()
self.maybe_free_model_hooks()
if not return_dict:
return (audio,)
return AudioPipelineOutput(audios=audio)
-6
View File
@@ -43,14 +43,12 @@ else:
_import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
_import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"]
_import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"]
_import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"]
_import_structure["scheduling_ddpm"] = ["DDPMScheduler"]
_import_structure["scheduling_ddpm_parallel"] = ["DDPMParallelScheduler"]
_import_structure["scheduling_ddpm_wuerstchen"] = ["DDPMWuerstchenScheduler"]
_import_structure["scheduling_deis_multistep"] = ["DEISMultistepScheduler"]
_import_structure["scheduling_dpm_cogvideox"] = ["CogVideoXDPMScheduler"]
_import_structure["scheduling_dpmsolver_multistep"] = ["DPMSolverMultistepScheduler"]
_import_structure["scheduling_dpmsolver_multistep_inverse"] = ["DPMSolverMultistepInverseScheduler"]
_import_structure["scheduling_dpmsolver_singlestep"] = ["DPMSolverSinglestepScheduler"]
@@ -120,7 +118,6 @@ except OptionalDependencyNotAvailable:
_dummy_modules.update(get_objects_from_module(dummy_torch_and_torchsde_objects))
else:
_import_structure["scheduling_cosine_dpmsolver_multistep"] = ["CosineDPMSolverMultistepScheduler"]
_import_structure["scheduling_dpmsolver_sde"] = ["DPMSolverSDEScheduler"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -143,14 +140,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
from .scheduling_consistency_models import CMStochasticIterativeScheduler
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
from .scheduling_ddim_inverse import DDIMInverseScheduler
from .scheduling_ddim_parallel import DDIMParallelScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_ddpm_parallel import DDPMParallelScheduler
from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler
from .scheduling_deis_multistep import DEISMultistepScheduler
from .scheduling_dpm_cogvideox import CogVideoXDPMScheduler
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler
from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
@@ -210,7 +205,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403
else:
from .scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler
from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler
else:
@@ -1,572 +0,0 @@
# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_dpmsolver_sde import BrownianTreeNoiseSampler
from .scheduling_utils import SchedulerMixin, SchedulerOutput
class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
Implements a variant of `DPMSolverMultistepScheduler` with cosine schedule, proposed by Nichol and Dhariwal (2021).
This scheduler was used in Stable Audio Open [1].
[1] Evans, Parker, et al. "Stable Audio Open" https://arxiv.org/abs/2407.14358
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
sigma_min (`float`, *optional*, defaults to 0.3):
Minimum noise magnitude in the sigma schedule. This was set to 0.3 in Stable Audio Open [1].
sigma_max (`float`, *optional*, defaults to 500):
Maximum noise magnitude in the sigma schedule. This was set to 500 in Stable Audio Open [1].
sigma_data (`float`, *optional*, defaults to 1.0):
The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1].
sigma_schedule (`str`, *optional*, defaults to `exponential`):
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
(https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
incorporated in this model: https://huggingface.co/stabilityai/cosxl.
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
solver_order (`int`, defaults to 2):
The DPMSolver order which can be `1` or `2`. It is recommended to use `solver_order=2`.
prediction_type (`str`, defaults to `v_prediction`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper).
solver_type (`str`, defaults to `midpoint`):
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
euler_at_final (`bool`, defaults to `False`):
Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
steps, but sometimes may result in blurring.
final_sigmas_type (`str`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
"""
_compatibles = []
order = 1
@register_to_config
def __init__(
self,
sigma_min: float = 0.3,
sigma_max: float = 500,
sigma_data: float = 1.0,
sigma_schedule: str = "exponential",
num_train_timesteps: int = 1000,
solver_order: int = 2,
prediction_type: str = "v_prediction",
rho: float = 7.0,
solver_type: str = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
):
if solver_type not in ["midpoint", "heun"]:
if solver_type in ["logrho", "bh1", "bh2"]:
self.register_to_config(solver_type="midpoint")
else:
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
ramp = torch.linspace(0, 1, num_train_timesteps)
if sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(ramp)
elif sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(ramp)
self.timesteps = self.precondition_noise(sigmas)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
# setable values
self.num_inference_steps = None
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
return (self.config.sigma_max**2 + 1) ** 0.5
@property
def step_index(self):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
def precondition_inputs(self, sample, sigma):
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
scaled_sample = sample * c_in
return scaled_sample
def precondition_noise(self, sigma):
if not isinstance(sigma, torch.Tensor):
sigma = torch.tensor([sigma])
return sigma.atan() / math.pi * 2
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
def precondition_outputs(self, sample, model_output, sigma):
sigma_data = self.config.sigma_data
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
if self.config.prediction_type == "epsilon":
c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
elif self.config.prediction_type == "v_prediction":
c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
else:
raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")
denoised = c_skip * sample + c_out * model_output
return denoised
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Args:
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.Tensor`:
A scaled input sample.
"""
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
sample = self.precondition_inputs(sample, sigma)
self.is_scale_input_called = True
return sample
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
ramp = torch.linspace(0, 1, self.num_inference_steps)
if self.config.sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(ramp)
elif self.config.sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(ramp)
sigmas = sigmas.to(dtype=torch.float32, device=device)
self.timesteps = self.precondition_noise(sigmas)
if self.config.final_sigmas_type == "sigma_min":
sigma_last = self.config.sigma_min
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
self.sigmas = torch.cat([sigmas, torch.tensor([sigma_last], dtype=torch.float32, device=device)])
self.model_outputs = [
None,
] * self.config.solver_order
self.lower_order_nums = 0
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# if a noise sampler is used, reinitialise it
self.noise_sampler = None
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
rho = self.config.rho
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Implementation closely follows k-diffusion.
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
# get sigmas range
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1
sigma_t = sigma
return alpha_t, sigma_t
def convert_model_output(
self,
model_output: torch.Tensor,
sample: torch.Tensor = None,
) -> torch.Tensor:
"""
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
<Tip>
The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
prediction and data prediction models.
</Tip>
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.Tensor`:
The converted model output.
"""
sigma = self.sigmas[self.step_index]
x0_pred = self.precondition_outputs(sample, model_output, sigma)
return x0_pred
def dpm_solver_first_order_update(
self,
model_output: torch.Tensor,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
One step for the first-order DPMSolver (equivalent to DDIM).
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
h = lambda_t - lambda_s
assert noise is not None
x_t = (
(sigma_t / sigma_s * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.Tensor],
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
One step for the second-order multistep DPMSolver.
Args:
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
m0, m1 = model_output_list[-1], model_output_list[-2]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
# sde-dpmsolver++
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
"""
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
the multistep DPMSolver.
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
# Improve numerical stability for small number of steps
lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
self.config.euler_at_final
or (self.config.lower_order_final and len(self.timesteps) < 15)
or self.config.final_sigmas_type == "zero"
)
lower_order_second = (
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
)
model_output = self.convert_model_output(model_output, sample=sample)
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output
if self.noise_sampler is None:
seed = None
if generator is not None:
seed = (
[g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed()
)
self.noise_sampler = BrownianTreeNoiseSampler(
model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed
)
noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to(
model_output.device
)
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called before first denoising step to create initial latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
noisy_samples = original_samples + noise * sigma
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps
@@ -1,450 +0,0 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->CogVideoXDDIM
class CogVideoXDDIMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.Tensor
pred_original_sample: Optional[torch.Tensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
def rescale_zero_terminal_snr(alphas_cumprod):
"""
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.Tensor`: rescaled betas with zero terminal SNR
"""
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
return alphas_bar
class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
"""
`DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
non-Markovian guidance.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.0001):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
clip_sample (`bool`, defaults to `True`):
Clip the predicted sample for numerical stability.
clip_sample_range (`float`, defaults to 1.0):
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
set_alpha_to_one (`bool`, defaults to `True`):
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the alpha value at step 0.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper).
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
dynamic_thresholding_ratio (`float`, defaults to 0.995):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.0120,
beta_schedule: str = "scaled_linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
rescale_betas_zero_snr: bool = False,
snr_shift_scale: float = 3.0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# Modify: SNR shift following SD3
self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod)
# Rescale for zero SNR
if rescale_betas_zero_snr:
self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
return variance
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.Tensor`:
A scaled input sample.
"""
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
"""
if num_inference_steps > self.config.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.config.num_train_timesteps} timesteps."
)
self.num_inference_steps = num_inference_steps
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
.round()[::-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
)
self.timesteps = torch.from_numpy(timesteps).to(device)
def step(
self,
model_output: torch.Tensor,
timestep: int,
sample: torch.Tensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[CogVideoXDDIMSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
The weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`, defaults to `False`):
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
`use_clipped_model_output` has no effect.
generator (`torch.Generator`, *optional*):
A random number generator.
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`CycleDiffusion`].
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_ddim_cogvideox.CogVideoXDDIMSchedulerOutput`] or
`tuple`.
Returns:
[`~schedulers.scheduling_ddim_cogvideox.CogVideoXDDIMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_ddim_cogvideox.CogVideoXDDIMSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_sample -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
# To make style tests pass, commented out `pred_epsilon` as it is an unused variable
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
# pred_epsilon = model_output
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
# pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`"
)
a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5
b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t
prev_sample = a_t * sample + b_t * pred_original_sample
if not return_dict:
return (prev_sample,)
return CogVideoXDDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
timesteps = timesteps.to(sample.device)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(sample.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
def __len__(self):
return self.config.num_train_timesteps

Some files were not shown because too many files have changed in this diff Show More