Compare commits
81 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 85b3f08c26 | |||
| 19f3161d94 | |||
| a1fdfca36f | |||
| d1e20be664 | |||
| af3854d6ad | |||
| 9f1936d2fc | |||
| fbca2e0a7a | |||
| 3768d4d77c | |||
| 8ccb619416 | |||
| 0699ac62f0 | |||
| a76f2ad538 | |||
| 7200daa412 | |||
| 3eeaf4e041 | |||
| c583f3b452 | |||
| 12358b986f | |||
| 5eeedd9e33 | |||
| a971c598b5 | |||
| 934d439a42 | |||
| e3f3672f46 | |||
| 87ae330056 | |||
| 1b46c66132 | |||
| 031358988b | |||
| fd35689f25 | |||
| e8c9069d6f | |||
| 766aa50f70 | |||
| c4d2823601 | |||
| 4f8853e481 | |||
| fed88195e3 | |||
| 0de35e4a52 | |||
| 0d81e543a2 | |||
| 3be0ff9056 | |||
| 2764db3194 | |||
| 048d901993 | |||
| cb432c4ebc | |||
| b7b1a30bc4 | |||
| 7e5587a5ac | |||
| dc8da1d449 | |||
| 3dd540171d | |||
| b3b2d30cd8 | |||
| 3bba44d74e | |||
| b1290d3fb8 | |||
| 29a11c2a94 | |||
| cdacd8f1dd | |||
| 470d51c8ed | |||
| d6141205cd | |||
| 4447547eda | |||
| 5222294748 | |||
| c25c46137d | |||
| 3105c710ba | |||
| 58f5f748f4 | |||
| 4f05058bb7 | |||
| 5d4413001b | |||
| 863e741614 | |||
| 24c5e7708b | |||
| cd21b965d1 | |||
| d185b5ed5f | |||
| 709a642827 | |||
| 0a0fe69aa6 | |||
| 124e76ddc6 | |||
| 05b0ec63bc | |||
| 4909b1e3ac | |||
| 052bf3280b | |||
| 80871ac597 | |||
| 6abc66ef28 | |||
| 38efac9f61 | |||
| 4f6399bedd | |||
| 6e1af3a777 | |||
| f22aad6e3a | |||
| ecded50ad5 | |||
| e34d9aa681 | |||
| 8d30d25794 | |||
| 1e0395e791 | |||
| 9141c1f9d5 | |||
| f75b8aa9dd | |||
| 7a24977ce3 | |||
| 74d902eb59 | |||
| d7c4ae619d | |||
| 67ea2b7afa | |||
| a10107f92b | |||
| d0c30cfd37 | |||
| 7c3e7fedcd |
+22
-12
@@ -36,38 +36,42 @@
|
||||
title: Push files to the Hub
|
||||
title: Loading & Hub
|
||||
- sections:
|
||||
- local: using-diffusers/pipeline_overview
|
||||
title: Overview
|
||||
- local: using-diffusers/unconditional_image_generation
|
||||
title: Unconditional image generation
|
||||
- local: using-diffusers/conditional_image_generation
|
||||
title: Text-to-image generation
|
||||
title: Text-to-image
|
||||
- local: using-diffusers/img2img
|
||||
title: Text-guided image-to-image
|
||||
title: Image-to-image
|
||||
- local: using-diffusers/inpaint
|
||||
title: Text-guided image-inpainting
|
||||
title: Inpainting
|
||||
- local: using-diffusers/depth2img
|
||||
title: Text-guided depth-to-image
|
||||
title: Depth-to-image
|
||||
title: Tasks
|
||||
- sections:
|
||||
- local: using-diffusers/textual_inversion_inference
|
||||
title: Textual inversion
|
||||
- local: training/distributed_inference
|
||||
title: Distributed inference with multiple GPUs
|
||||
- local: using-diffusers/distilled_sd
|
||||
title: Distilled Stable Diffusion inference
|
||||
- local: using-diffusers/reusing_seeds
|
||||
title: Improve image quality with deterministic generation
|
||||
- local: using-diffusers/control_brightness
|
||||
title: Control image brightness
|
||||
- local: using-diffusers/weighted_prompts
|
||||
title: Prompt weighting
|
||||
title: Techniques
|
||||
- sections:
|
||||
- local: using-diffusers/pipeline_overview
|
||||
title: Overview
|
||||
- local: using-diffusers/sdxl
|
||||
title: Stable Diffusion XL
|
||||
- local: using-diffusers/distilled_sd
|
||||
title: Distilled Stable Diffusion inference
|
||||
- local: using-diffusers/reproducibility
|
||||
title: Create reproducible pipelines
|
||||
- local: using-diffusers/custom_pipeline_examples
|
||||
title: Community pipelines
|
||||
- local: using-diffusers/contribute_pipeline
|
||||
title: How to contribute a community pipeline
|
||||
- local: using-diffusers/stable_diffusion_jax_how_to
|
||||
title: Stable Diffusion in JAX/Flax
|
||||
- local: using-diffusers/weighted_prompts
|
||||
title: Prompt weighting
|
||||
title: Pipelines for Inference
|
||||
- sections:
|
||||
- local: training/overview
|
||||
@@ -105,6 +109,8 @@
|
||||
title: Memory and Speed
|
||||
- local: optimization/torch2.0
|
||||
title: Torch2.0 support
|
||||
- local: using-diffusers/stable_diffusion_jax_how_to
|
||||
title: Stable Diffusion in JAX/Flax
|
||||
- local: optimization/xformers
|
||||
title: xFormers
|
||||
- local: optimization/onnx
|
||||
@@ -190,6 +196,8 @@
|
||||
title: Audio Diffusion
|
||||
- local: api/pipelines/audioldm
|
||||
title: AudioLDM
|
||||
- local: api/pipelines/audioldm2
|
||||
title: AudioLDM 2
|
||||
- local: api/pipelines/auto_pipeline
|
||||
title: AutoPipeline
|
||||
- local: api/pipelines/consistency_models
|
||||
@@ -222,6 +230,8 @@
|
||||
title: Latent Diffusion
|
||||
- local: api/pipelines/panorama
|
||||
title: MultiDiffusion
|
||||
- local: api/pipelines/musicldm
|
||||
title: MusicLDM
|
||||
- local: api/pipelines/paint_by_example
|
||||
title: PaintByExample
|
||||
- local: api/pipelines/paradigms
|
||||
|
||||
@@ -46,6 +46,5 @@ Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to le
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
|
||||
## AudioPipelineOutput
|
||||
[[autodoc]] pipelines.AudioPipelineOutput
|
||||
@@ -0,0 +1,93 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AudioLDM 2
|
||||
|
||||
AudioLDM 2 was proposed in [AudioLDM 2: Learning Holistic Audio Generation with Self-supervised Pretraining](https://arxiv.org/abs/2308.05734)
|
||||
by Haohe Liu et al. AudioLDM 2 takes a text prompt as input and predicts the corresponding audio. It can generate
|
||||
text-conditional sound effects, human speech and music.
|
||||
|
||||
Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview), AudioLDM 2
|
||||
is a text-to-audio _latent diffusion model (LDM)_ that learns continuous audio representations from text embeddings. Two
|
||||
text encoder models are used to compute the text embeddings from a prompt input: the text-branch of [CLAP](https://huggingface.co/docs/transformers/main/en/model_doc/clap)
|
||||
and the encoder of [Flan-T5](https://huggingface.co/docs/transformers/main/en/model_doc/flan-t5). These text embeddings
|
||||
are then projected to a shared embedding space by an [AudioLDM2ProjectionModel](https://huggingface.co/docs/diffusers/main/api/pipelines/audioldm2#diffusers.AudioLDM2ProjectionModel).
|
||||
A [GPT2](https://huggingface.co/docs/transformers/main/en/model_doc/gpt2) _language model (LM)_ is used to auto-regressively
|
||||
predict eight new embedding vectors, conditional on the projected CLAP and Flan-T5 embeddings. The generated embedding
|
||||
vectors and Flan-T5 text embeddings are used as cross-attention conditioning in the LDM. The [UNet](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2UNet2DConditionModel)
|
||||
of AudioLDM 2 is unique in the sense that it takes **two** cross-attention embeddings, as opposed to one cross-attention
|
||||
conditioning, as in most other LDMs.
|
||||
|
||||
The abstract of the paper is the following:
|
||||
|
||||
*Although audio generation shares commonalities across different types of audio, such as speech, music, and sound effects, designing models for each type requires careful consideration of specific objectives and biases that can significantly differ from those of other types. To bring us closer to a unified perspective of audio generation, this paper proposes a framework that utilizes the same learning method for speech, music, and sound effect generation. Our framework introduces a general representation of audio, called language of audio (LOA). Any audio can be translated into LOA based on AudioMAE, a self-supervised pre-trained representation learning model. In the generation process, we translate any modalities into LOA by using a GPT-2 model, and we perform self-supervised audio generation learning with a latent diffusion model conditioned on LOA. The proposed framework naturally brings advantages such as in-context learning abilities and reusable self-supervised pretrained AudioMAE and latent diffusion models. Experiments on the major benchmarks of text-to-audio, text-to-music, and text-to-speech demonstrate new state-of-the-art or competitive performance to previous approaches.*
|
||||
|
||||
This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi). The original codebase can be
|
||||
found at [haoheliu/audioldm2](https://github.com/haoheliu/audioldm2).
|
||||
|
||||
## Tips
|
||||
|
||||
### Choosing a checkpoint
|
||||
|
||||
AudioLDM2 comes in three variants. Two of these checkpoints are applicable to the general task of text-to-audio
|
||||
generation. The third checkpoint is trained exclusively on text-to-music generation.
|
||||
|
||||
All checkpoints share the same model size for the text encoders and VAE. They differ in the size and depth of the UNet.
|
||||
See table below for details on the three checkpoints:
|
||||
|
||||
| Checkpoint | Task | UNet Model Size | Total Model Size | Training Data / h |
|
||||
|-----------------------------------------------------------------|---------------|-----------------|------------------|-------------------|
|
||||
| [audioldm2](https://huggingface.co/cvssp/audioldm2) | Text-to-audio | 350M | 1.1B | 1150k |
|
||||
| [audioldm2-large](https://huggingface.co/cvssp/audioldm2-large) | Text-to-audio | 750M | 1.5B | 1150k |
|
||||
| [audioldm2-music](https://huggingface.co/cvssp/audioldm2-music) | Text-to-music | 350M | 1.1B | 665k |
|
||||
|
||||
### Constructing a prompt
|
||||
|
||||
* Descriptive prompt inputs work best: use adjectives to describe the sound (e.g. "high quality" or "clear") and make the prompt context specific (e.g. "water stream in a forest" instead of "stream").
|
||||
* It's best to use general terms like "cat" or "dog" instead of specific names or abstract objects the model may not be familiar with.
|
||||
* Using a **negative prompt** can significantly improve the quality of the generated waveform, by guiding the generation away from terms that correspond to poor quality audio. Try using a negative prompt of "Low quality."
|
||||
|
||||
### Controlling inference
|
||||
|
||||
* The _quality_ of the predicted audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference.
|
||||
* The _length_ of the predicted audio sample can be controlled by varying the `audio_length_in_s` argument.
|
||||
|
||||
### Evaluating generated waveforms:
|
||||
|
||||
* The quality of the generated waveforms can vary significantly based on the seed. Try generating with different seeds until you find a satisfactory generation
|
||||
* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.
|
||||
|
||||
The following example demonstrates how to construct good music generation using the aforementioned tips: [example](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2Pipeline.__call__.example).
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to learn how to explore the tradeoff between
|
||||
scheduler speed and quality, and see the [reuse components across pipelines](/using-diffusers/loading#reuse-components-across-pipelines)
|
||||
section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## AudioLDM2Pipeline
|
||||
[[autodoc]] AudioLDM2Pipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## AudioLDM2ProjectionModel
|
||||
[[autodoc]] AudioLDM2ProjectionModel
|
||||
- forward
|
||||
|
||||
## AudioLDM2UNet2DConditionModel
|
||||
[[autodoc]] AudioLDM2UNet2DConditionModel
|
||||
- forward
|
||||
|
||||
## AudioPipelineOutput
|
||||
[[autodoc]] pipelines.AudioPipelineOutput
|
||||
@@ -0,0 +1,57 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# MusicLDM
|
||||
|
||||
MusicLDM was proposed in [MusicLDM: Enhancing Novelty in Text-to-Music Generation Using Beat-Synchronous Mixup Strategies](https://huggingface.co/papers/2308.01546) by Ke Chen, Yusong Wu, Haohe Liu, Marianna Nezhurina, Taylor Berg-Kirkpatrick, Shlomo Dubnov.
|
||||
MusicLDM takes a text prompt as input and predicts the corresponding music sample.
|
||||
|
||||
Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview) and [AudioLDM](https://huggingface.co/docs/diffusers/api/pipelines/audioldm/overview),
|
||||
MusicLDM is a text-to-music _latent diffusion model (LDM)_ that learns continuous audio representations from [CLAP](https://huggingface.co/docs/transformers/main/model_doc/clap)
|
||||
latents.
|
||||
|
||||
MusicLDM is trained on a corpus of 466 hours of music data. Beat-synchronous data augmentation strategies are applied to
|
||||
the music samples, both in the time domain and in the latent space. Using beat-synchronous data augmentation strategies
|
||||
encourages the model to interpolate between the training samples, but stay within the domain of the training data. The
|
||||
result is generated music that is more diverse while staying faithful to the corresponding style.
|
||||
|
||||
The abstract of the paper is the following:
|
||||
|
||||
*In this paper, we present MusicLDM, a state-of-the-art text-to-music model that adapts Stable Diffusion and AudioLDM architectures to the music domain. We achieve this by retraining the contrastive language-audio pretraining model (CLAP) and the Hifi-GAN vocoder, as components of MusicLDM, on a collection of music data samples. Then, we leverage a beat tracking model and propose two different mixup strategies for data augmentation: beat-synchronous audio mixup and beat-synchronous latent mixup, to encourage the model to generate music more diverse while still staying faithful to the corresponding style.*
|
||||
|
||||
This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi).
|
||||
|
||||
## Tips
|
||||
|
||||
When constructing a prompt, keep in mind:
|
||||
|
||||
* Descriptive prompt inputs work best; use adjectives to describe the sound (for example, "high quality" or "clear") and make the prompt context specific where possible (e.g. "melodic techno with a fast beat and synths" works better than "techno").
|
||||
* Using a *negative prompt* can significantly improve the quality of the generated audio. Try using a negative prompt of "low quality, average quality".
|
||||
|
||||
During inference:
|
||||
|
||||
* The _quality_ of the generated audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference.
|
||||
* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1 to enable. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.
|
||||
* The _length_ of the generated audio sample can be controlled by varying the `audio_length_in_s` argument.
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to learn how to explore the tradeoff between
|
||||
scheduler speed and quality, and see the [reuse components across pipelines](/using-diffusers/loading#reuse-components-across-pipelines)
|
||||
section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## MusicLDMPipeline
|
||||
[[autodoc]] MusicLDMPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -29,10 +29,11 @@ This model was contributed by the community contributor [HimariO](https://github
|
||||
| Pipeline | Tasks | Demo
|
||||
|---|---|:---:|
|
||||
| [StableDiffusionAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning* | -
|
||||
| [StableDiffusionXLAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_xl_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning on StableDiffusion-XL* | -
|
||||
|
||||
## Usage example
|
||||
## Usage example with the base model of StableDiffusion-1.4/1.5
|
||||
|
||||
In the following we give a simple example of how to use a *T2IAdapter* checkpoint with Diffusers for inference.
|
||||
In the following we give a simple example of how to use a *T2IAdapter* checkpoint with Diffusers for inference based on StableDiffusion-1.4/1.5.
|
||||
All adapters use the same pipeline.
|
||||
|
||||
1. Images are first converted into the appropriate *control image* format.
|
||||
@@ -93,6 +94,62 @@ out_image = pipe(
|
||||
|
||||

|
||||
|
||||
## Usage example with the base model of StableDiffusion-XL
|
||||
|
||||
In the following we give a simple example of how to use a *T2IAdapter* checkpoint with Diffusers for inference based on StableDiffusion-XL.
|
||||
All adapters use the same pipeline.
|
||||
|
||||
1. Images are first downloaded into the appropriate *control image* format.
|
||||
2. The *control image* and *prompt* are passed to the [`StableDiffusionXLAdapterPipeline`].
|
||||
|
||||
Let's have a look at a simple example using the [Sketch Adapter](https://huggingface.co/Adapter/t2iadapter/tree/main/sketch_sdxl_1.0).
|
||||
|
||||
```python
|
||||
from diffusers.utils import load_image
|
||||
|
||||
sketch_image = load_image("https://huggingface.co/Adapter/t2iadapter/resolve/main/sketch.png").convert("L")
|
||||
```
|
||||
|
||||

|
||||
|
||||
Then, create the adapter pipeline
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import (
|
||||
T2IAdapter,
|
||||
StableDiffusionXLAdapterPipeline,
|
||||
DDPMScheduler
|
||||
)
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
adapter = T2IAdapter.from_pretrained("Adapter/t2iadapter", subfolder="sketch_sdxl_1.0",torch_dtype=torch.float16, adapter_type="full_adapter_xl")
|
||||
scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
|
||||
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
|
||||
model_id, adapter=adapter, safety_checker=None, torch_dtype=torch.float16, variant="fp16", scheduler=scheduler
|
||||
)
|
||||
|
||||
pipe.to("cuda")
|
||||
```
|
||||
|
||||
Finally, pass the prompt and control image to the pipeline
|
||||
|
||||
```py
|
||||
# fix the random seed, so you will get the same result as the example
|
||||
generator = torch.Generator().manual_seed(42)
|
||||
|
||||
sketch_image_out = pipe(
|
||||
prompt="a photo of a dog in real world, high quality",
|
||||
negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
|
||||
image=sketch_image,
|
||||
generator=generator,
|
||||
guidance_scale=7.5
|
||||
).images[0]
|
||||
```
|
||||
|
||||

|
||||
|
||||
## Available checkpoints
|
||||
|
||||
@@ -113,6 +170,9 @@ Non-diffusers checkpoints can be found under [TencentARC/T2I-Adapter](https://hu
|
||||
|[TencentARC/t2iadapter_depth_sd15v2](https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2)||
|
||||
|[TencentARC/t2iadapter_sketch_sd15v2](https://huggingface.co/TencentARC/t2iadapter_sketch_sd15v2)||
|
||||
|[TencentARC/t2iadapter_zoedepth_sd15v1](https://huggingface.co/TencentARC/t2iadapter_zoedepth_sd15v1)||
|
||||
|[Adapter/t2iadapter, subfolder='sketch_sdxl_1.0'](https://huggingface.co/Adapter/t2iadapter/tree/main/sketch_sdxl_1.0)||
|
||||
|[Adapter/t2iadapter, subfolder='canny_sdxl_1.0'](https://huggingface.co/Adapter/t2iadapter/tree/main/canny_sdxl_1.0)||
|
||||
|[Adapter/t2iadapter, subfolder='openpose_sdxl_1.0'](https://huggingface.co/Adapter/t2iadapter/tree/main/openpose_sdxl_1.0)||
|
||||
|
||||
## Combining multiple adapters
|
||||
|
||||
@@ -185,3 +245,14 @@ However, T2I-Adapter performs slightly worse than ControlNet.
|
||||
- disable_vae_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
|
||||
## StableDiffusionXLAdapterPipeline
|
||||
[[autodoc]] StableDiffusionXLAdapterPipeline
|
||||
- all
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
- enable_vae_slicing
|
||||
- disable_vae_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
|
||||
@@ -10,382 +10,29 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Stable diffusion XL
|
||||
# Stable Diffusion XL
|
||||
|
||||
Stable Diffusion XL was proposed in [SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://arxiv.org/abs/2307.01952) by Dustin Podell, Zion English, Kyle Lacey, Andreas Blattmann, Tim Dockhorn, Jonas Müller, Joe Penna, Robin Rombach
|
||||
Stable Diffusion XL (SDXL) was proposed in [SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://huggingface.co/papers/2307.01952) by Dustin Podell, Zion English, Kyle Lacey, Andreas Blattmann, Tim Dockhorn, Jonas Müller, Joe Penna, and Robin Rombach.
|
||||
|
||||
The abstract of the paper is the following:
|
||||
The abstract from the paper is:
|
||||
|
||||
*We present SDXL, a latent diffusion model for text-to-image synthesis. Compared to previous versions of Stable Diffusion, SDXL leverages a three times larger UNet backbone: The increase of model parameters is mainly due to more attention blocks and a larger cross-attention context as SDXL uses a second text encoder. We design multiple novel conditioning schemes and train SDXL on multiple aspect ratios. We also introduce a refinement model which is used to improve the visual fidelity of samples generated by SDXL using a post-hoc image-to-image technique. We demonstrate that SDXL shows drastically improved performance compared the previous versions of Stable Diffusion and achieves results competitive with those of black-box state-of-the-art image generators.*
|
||||
|
||||
## Tips
|
||||
|
||||
- Stable Diffusion XL works especially well with images between 768 and 1024.
|
||||
- Stable Diffusion XL can pass a different prompt for each of the text encoders it was trained on as shown below. We can even pass different parts of the same prompt to the text encoders.
|
||||
- Stable Diffusion XL output image can be improved by making use of a refiner as shown below.
|
||||
|
||||
### Available checkpoints:
|
||||
|
||||
- *Text-to-Image (1024x1024 resolution)*: [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) with [`StableDiffusionXLPipeline`]
|
||||
- *Image-to-Image / Refiner (1024x1024 resolution)*: [stabilityai/stable-diffusion-xl-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0) with [`StableDiffusionXLImg2ImgPipeline`]
|
||||
|
||||
## Usage Example
|
||||
|
||||
Before using SDXL make sure to have `transformers`, `accelerate`, `safetensors` and `invisible_watermark` installed.
|
||||
You can install the libraries as follows:
|
||||
|
||||
```
|
||||
pip install transformers
|
||||
pip install accelerate
|
||||
pip install safetensors
|
||||
```
|
||||
|
||||
### Watermarker
|
||||
|
||||
We recommend to add an invisible watermark to images generating by Stable Diffusion XL, this can help with identifying if an image is machine-synthesised for downstream applications. To do so, please install
|
||||
the [invisible-watermark library](https://pypi.org/project/invisible-watermark/) via:
|
||||
|
||||
```
|
||||
pip install invisible-watermark>=0.2.0
|
||||
```
|
||||
|
||||
If the `invisible-watermark` library is installed the watermarker will be used **by default**.
|
||||
|
||||
If you have other provisions for generating or deploying images safely, you can disable the watermarker as follows:
|
||||
|
||||
```py
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(..., add_watermarker=False)
|
||||
```
|
||||
|
||||
### Text-to-Image
|
||||
|
||||
You can use SDXL as follows for *text-to-image*:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = pipe(prompt=prompt).images[0]
|
||||
```
|
||||
|
||||
### Image-to-image
|
||||
|
||||
You can use SDXL as follows for *image-to-image*:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLImg2ImgPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
|
||||
|
||||
init_image = load_image(url).convert("RGB")
|
||||
prompt = "a photo of an astronaut riding a horse on mars"
|
||||
image = pipe(prompt, image=init_image).images[0]
|
||||
```
|
||||
|
||||
### Inpainting
|
||||
|
||||
You can use SDXL as follows for *inpainting*
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLInpaintPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
|
||||
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
|
||||
init_image = load_image(img_url).convert("RGB")
|
||||
mask_image = load_image(mask_url).convert("RGB")
|
||||
|
||||
prompt = "A majestic tiger sitting on a bench"
|
||||
image = pipe(prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80).images[0]
|
||||
```
|
||||
|
||||
### Refining the image output
|
||||
|
||||
In addition to the [base model checkpoint](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0),
|
||||
StableDiffusion-XL also includes a [refiner checkpoint](huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)
|
||||
that is specialized in denoising low-noise stage images to generate images of improved high-frequency quality.
|
||||
This refiner checkpoint can be used as a "second-step" pipeline after having run the base checkpoint to improve
|
||||
image quality.
|
||||
|
||||
When using the refiner, one can easily
|
||||
- 1.) employ the base model and refiner as an *Ensemble of Expert Denoisers* as first proposed in [eDiff-I](https://research.nvidia.com/labs/dir/eDiff-I/) or
|
||||
- 2.) simply run the refiner in [SDEdit](https://arxiv.org/abs/2108.01073) fashion after the base model.
|
||||
|
||||
**Note**: The idea of using SD-XL base & refiner as an ensemble of experts was first brought forward by
|
||||
a couple community contributors which also helped shape the following `diffusers` implementation, namely:
|
||||
- [SytanSD](https://github.com/SytanSD)
|
||||
- [bghira](https://github.com/bghira)
|
||||
- [Birch-san](https://github.com/Birch-san)
|
||||
- [AmericanPresidentJimmyCarter](https://github.com/AmericanPresidentJimmyCarter)
|
||||
|
||||
#### 1.) Ensemble of Expert Denoisers
|
||||
|
||||
When using the base and refiner model as an ensemble of expert of denoisers, the base model should serve as the
|
||||
expert for the high-noise diffusion stage and the refiner serves as the expert for the low-noise diffusion stage.
|
||||
|
||||
The advantage of 1.) over 2.) is that it requires less overall denoising steps and therefore should be significantly
|
||||
faster. The drawback is that one cannot really inspect the output of the base model; it will still be heavily denoised.
|
||||
|
||||
To use the base model and refiner as an ensemble of expert denoisers, make sure to define the span
|
||||
of timesteps which should be run through the high-noise denoising stage (*i.e.* the base model) and the low-noise
|
||||
denoising stage (*i.e.* the refiner model) respectively. We can set the intervals using the [`denoising_end`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline.__call__.denoising_end) of the base model
|
||||
and [`denoising_start`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline.__call__.denoising_start) of the refiner model.
|
||||
|
||||
For both `denoising_end` and `denoising_start` a float value between 0 and 1 should be passed.
|
||||
When passed, the end and start of denoising will be defined by proportions of discrete timesteps as
|
||||
defined by the model schedule.
|
||||
Note that this will override `strength` if it is also declared, since the number of denoising steps
|
||||
is determined by the discrete timesteps the model was trained on and the declared fractional cutoff.
|
||||
|
||||
Let's look at an example.
|
||||
First, we import the two pipelines. Since the text encoders and variational autoencoder are the same
|
||||
you don't have to load those again for the refiner.
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
base = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
)
|
||||
base.to("cuda")
|
||||
|
||||
refiner = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0",
|
||||
text_encoder_2=base.text_encoder_2,
|
||||
vae=base.vae,
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
)
|
||||
refiner.to("cuda")
|
||||
```
|
||||
|
||||
Now we define the number of inference steps and the point at which the model shall be run through the
|
||||
high-noise denoising stage (*i.e.* the base model).
|
||||
|
||||
```py
|
||||
n_steps = 40
|
||||
high_noise_frac = 0.8
|
||||
```
|
||||
|
||||
Stable Diffusion XL base is trained on timesteps 0-999 and Stable Diffusion XL refiner is finetuned
|
||||
from the base model on low noise timesteps 0-199 inclusive, so we use the base model for the first
|
||||
800 timesteps (high noise) and the refiner for the last 200 timesteps (low noise). Hence, `high_noise_frac`
|
||||
is set to 0.8, so that all steps 200-999 (the first 80% of denoising timesteps) are performed by the
|
||||
base model and steps 0-199 (the last 20% of denoising timesteps) are performed by the refiner model.
|
||||
|
||||
Remember, the denoising process starts at **high value** (high noise) timesteps and ends at
|
||||
**low value** (low noise) timesteps.
|
||||
|
||||
Let's run the two pipelines now. Make sure to set `denoising_end` and
|
||||
`denoising_start` to the same values and keep `num_inference_steps` constant. Also remember that
|
||||
the output of the base model should be in latent space:
|
||||
|
||||
```py
|
||||
prompt = "A majestic lion jumping from a big stone at night"
|
||||
|
||||
image = base(
|
||||
prompt=prompt,
|
||||
num_inference_steps=n_steps,
|
||||
denoising_end=high_noise_frac,
|
||||
output_type="latent",
|
||||
).images
|
||||
image = refiner(
|
||||
prompt=prompt,
|
||||
num_inference_steps=n_steps,
|
||||
denoising_start=high_noise_frac,
|
||||
image=image,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
Let's have a look at the images
|
||||
|
||||
| Original Image | Ensemble of Denoisers Experts |
|
||||
|---|---|
|
||||
|  | 
|
||||
|
||||
If we would have just run the base model on the same 40 steps, the image would have been arguably less detailed (e.g. the lion eyes and nose):
|
||||
- SDXL works especially well with images between 768 and 1024.
|
||||
- SDXL can pass a different prompt for each of the text encoders it was trained on. We can even pass different parts of the same prompt to the text encoders.
|
||||
- SDXL output images can be improved by making use of a refiner model in an image-to-image setting.
|
||||
- SDXL offers `negative_original_size`, `negative_crops_coords_top_left`, and `negative_target_size` to negatively condition the model on image resolution and cropping parameters.
|
||||
|
||||
<Tip>
|
||||
|
||||
The ensemble-of-experts method works well on all available schedulers!
|
||||
To learn how to use SDXL for various tasks, how to optimize performance, and other usage examples, take a look at the [Stable Diffusion XL](/using-diffusers/sdxl) guide.
|
||||
|
||||
Check out the [Stability AI](https://huggingface.co/stabilityai) Hub organization for the official base and refiner model checkpoints!
|
||||
|
||||
</Tip>
|
||||
|
||||
#### 2.) Refining the image output from fully denoised base image
|
||||
|
||||
In standard [`StableDiffusionImg2ImgPipeline`]-fashion, the fully-denoised image generated of the base model
|
||||
can be further improved using the [refiner checkpoint](huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0).
|
||||
|
||||
For this, you simply run the refiner as a normal image-to-image pipeline after the "base" text-to-image
|
||||
pipeline. You can leave the outputs of the base model in latent space.
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
refiner = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0",
|
||||
text_encoder_2=pipe.text_encoder_2,
|
||||
vae=pipe.vae,
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
)
|
||||
refiner.to("cuda")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
|
||||
image = pipe(prompt=prompt, output_type="latent" if use_refiner else "pil").images[0]
|
||||
image = refiner(prompt=prompt, image=image[None, :]).images[0]
|
||||
```
|
||||
|
||||
| Original Image | Refined Image |
|
||||
|---|---|
|
||||
|  |  |
|
||||
|
||||
<Tip>
|
||||
|
||||
The refiner can also very well be used in an in-painting setting. To do so just make
|
||||
sure you use the [`StableDiffusionXLInpaintPipeline`] classes as shown below
|
||||
|
||||
</Tip>
|
||||
|
||||
To use the refiner for inpainting in the Ensemble of Expert Denoisers setting you can do the following:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLInpaintPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0",
|
||||
text_encoder_2=pipe.text_encoder_2,
|
||||
vae=pipe.vae,
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
)
|
||||
refiner.to("cuda")
|
||||
|
||||
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
|
||||
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
|
||||
init_image = load_image(img_url).convert("RGB")
|
||||
mask_image = load_image(mask_url).convert("RGB")
|
||||
|
||||
prompt = "A majestic tiger sitting on a bench"
|
||||
num_inference_steps = 75
|
||||
high_noise_frac = 0.7
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
num_inference_steps=num_inference_steps,
|
||||
denoising_start=high_noise_frac,
|
||||
output_type="latent",
|
||||
).images
|
||||
image = refiner(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
mask_image=mask_image,
|
||||
num_inference_steps=num_inference_steps,
|
||||
denoising_start=high_noise_frac,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
To use the refiner for inpainting in the standard SDE-style setting, simply remove `denoising_end` and `denoising_start` and choose a smaller
|
||||
number of inference steps for the refiner.
|
||||
|
||||
### Loading single file checkpoints / original file format
|
||||
|
||||
By making use of [`~diffusers.loaders.FromSingleFileMixin.from_single_file`] you can also load the
|
||||
original file format into `diffusers`:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_single_file(
|
||||
"./sd_xl_base_1.0.safetensors", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
|
||||
"./sd_xl_refiner_1.0.safetensors", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
|
||||
)
|
||||
refiner.to("cuda")
|
||||
```
|
||||
|
||||
### Memory optimization via model offloading
|
||||
|
||||
If you are seeing out-of-memory errors, we recommend making use of [`StableDiffusionXLPipeline.enable_model_cpu_offload`].
|
||||
|
||||
```diff
|
||||
- pipe.to("cuda")
|
||||
+ pipe.enable_model_cpu_offload()
|
||||
```
|
||||
|
||||
and
|
||||
|
||||
```diff
|
||||
- refiner.to("cuda")
|
||||
+ refiner.enable_model_cpu_offload()
|
||||
```
|
||||
|
||||
### Speed-up inference with `torch.compile`
|
||||
|
||||
You can speed up inference by making use of `torch.compile`. This should give you **ca.** 20% speed-up.
|
||||
|
||||
```diff
|
||||
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
+ refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
|
||||
```
|
||||
|
||||
### Running with `torch < 2.0`
|
||||
|
||||
**Note** that if you want to run Stable Diffusion XL with `torch` < 2.0, please make sure to enable xformers
|
||||
attention:
|
||||
|
||||
```
|
||||
pip install xformers
|
||||
```
|
||||
|
||||
```diff
|
||||
+pipe.enable_xformers_memory_efficient_attention()
|
||||
+refiner.enable_xformers_memory_efficient_attention()
|
||||
```
|
||||
|
||||
## StableDiffusionXLPipeline
|
||||
|
||||
[[autodoc]] StableDiffusionXLPipeline
|
||||
@@ -403,25 +50,3 @@ pip install xformers
|
||||
[[autodoc]] StableDiffusionXLInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
### Passing different prompts to each text-encoder
|
||||
|
||||
Stable Diffusion XL was trained on two text encoders. The default behavior is to pass the same prompt to each. But it is possible to pass a different prompt for each text-encoder, as [some users](https://github.com/huggingface/diffusers/issues/4004#issuecomment-1627764201) noted that it can boost quality.
|
||||
To do so, you can pass `prompt_2` and `negative_prompt_2` in addition to `prompt` and `negative_prompt`. By doing that, you will pass the original prompts and negative prompts (as in `prompt` and `negative_prompt`) to `text_encoder` (in official SDXL 0.9/1.0 that is [OpenAI CLIP-ViT/L-14](https://huggingface.co/openai/clip-vit-large-patch14)),
|
||||
and `prompt_2` and `negative_prompt_2` to `text_encoder_2` (in official SDXL 0.9/1.0 that is [OpenCLIP-ViT/bigG-14](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
# prompt will be passed to OAI CLIP-ViT/L-14
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
# prompt_2 will be passed to OpenCLIP-ViT/bigG-14
|
||||
prompt_2 = "monet painting"
|
||||
image = pipe(prompt=prompt, prompt_2=prompt_2).images[0]
|
||||
```
|
||||
|
||||
@@ -20,6 +20,12 @@ The abstract from the [paper](https://arxiv.org/abs/2303.06555) is:
|
||||
|
||||
You can find the original codebase at [thu-ml/unidiffuser](https://github.com/thu-ml/unidiffuser) and additional checkpoints at [thu-ml](https://huggingface.co/thu-ml).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
There is currently an issue on PyTorch 1.X where the output images are all black or the pixel values become `NaNs`. This issue can be mitigated by switching to PyTorch 2.X.
|
||||
|
||||
</Tip>
|
||||
|
||||
This pipeline was contributed by [dg845](https://github.com/dg845). ❤️
|
||||
|
||||
## Usage Examples
|
||||
|
||||
@@ -334,7 +334,7 @@ image_processor = CLIPImageProcessor.from_pretrained(clip_id)
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to(device)
|
||||
```
|
||||
|
||||
Notice that we are using a particular CLIP checkpoint, i.e., `openai/clip-vit-large-patch14`. This is because the Stable Diffusion pre-training was performed with this CLIP variant. For more details, refer to the [documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/pix2pix#diffusers.StableDiffusionInstructPix2PixPipeline.text_encoder).
|
||||
Notice that we are using a particular CLIP checkpoint, i.e., `openai/clip-vit-large-patch14`. This is because the Stable Diffusion pre-training was performed with this CLIP variant. For more details, refer to the [documentation](https://huggingface.co/docs/transformers/model_doc/clip).
|
||||
|
||||
Next, we prepare a PyTorch `nn.Module` to compute directional similarity:
|
||||
|
||||
|
||||
@@ -265,7 +265,7 @@ distributed_type: DEEPSPEED
|
||||
|
||||
See [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.
|
||||
|
||||
<Tip>
|
||||
</Tip>
|
||||
|
||||
Changing the default Adam optimizer to DeepSpeed's Adam
|
||||
`deepspeed.ops.adam.DeepSpeedCPUAdam` gives a substantial speedup but
|
||||
@@ -330,4 +330,4 @@ image.save("./output.png")
|
||||
|
||||
## Stable Diffusion XL
|
||||
|
||||
Training with [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) is also supported via the `train_controlnet_sdxl.py` script. Please refer to the docs [here](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README_sdxl.md).
|
||||
Training with [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) is also supported via the `train_controlnet_sdxl.py` script. Please refer to the docs [here](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README_sdxl.md).
|
||||
|
||||
@@ -276,20 +276,40 @@ Note that the use of [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] is
|
||||
|
||||
* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth).
|
||||
|
||||
**Note** that it is possible to provide a local directory path to [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] as well as [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`]. To know about the supported inputs,
|
||||
refer to the respective docstrings.
|
||||
<Tip>
|
||||
|
||||
You can also provide a local directory path to [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] as well as [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`].
|
||||
|
||||
</Tip>
|
||||
|
||||
## Stable Diffusion XL
|
||||
|
||||
We support fine-tuning with [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). Please refer to the following docs:
|
||||
|
||||
* [text_to_image/README_sdxl.md](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/README_sdxl.md)
|
||||
* [dreambooth/README_sdxl.md](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sdxl.md)
|
||||
|
||||
## Unloading LoRA parameters
|
||||
|
||||
You can call [`~diffusers.loaders.LoraLoaderMixin.unload_lora_weights`] on a pipeline to unload the LoRA parameters.
|
||||
|
||||
## Supporting A1111 themed LoRA checkpoints from Diffusers
|
||||
## Fusing LoRA parameters
|
||||
|
||||
This support was made possible because of our amazing contributors: [@takuma104](https://github.com/takuma104) and [@isidentical](https://github.com/isidentical).
|
||||
You can call [`~diffusers.loaders.LoraLoaderMixin.fuse_lora`] on a pipeline to merge the LoRA parameters with the original parameters of the underlying model(s). This can lead to a potential speedup in the inference latency.
|
||||
|
||||
To provide seamless interoperability with A1111 to our users, we support loading A1111 formatted
|
||||
LoRA checkpoints using [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] in a limited capacity.
|
||||
In this section, we explain how to load an A1111 formatted LoRA checkpoint from [CivitAI](https://civitai.com/)
|
||||
## Unfusing LoRA parameters
|
||||
|
||||
To undo `fuse_lora`, call [`~diffusers.loaders.LoraLoaderMixin.unfuse_lora`] on a pipeline.
|
||||
|
||||
## Supporting different LoRA checkpoints from Diffusers
|
||||
|
||||
🤗 Diffusers supports loading checkpoints from popular LoRA trainers such as [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). In this section, we outline the current API's details and limitations.
|
||||
|
||||
### Kohya
|
||||
|
||||
This support was made possible because of the amazing contributors: [@takuma104](https://github.com/takuma104) and [@isidentical](https://github.com/isidentical).
|
||||
|
||||
We support loading Kohya LoRA checkpoints using [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`]. In this section, we explain how to load such a checkpoint from [CivitAI](https://civitai.com/)
|
||||
in Diffusers and perform inference with it.
|
||||
|
||||
First, download a checkpoint. We'll use
|
||||
@@ -356,9 +376,9 @@ lora_filename = "light_and_shadow.safetensors"
|
||||
pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
|
||||
```
|
||||
|
||||
### Supporting Stable Diffusion XL LoRAs trained using the Kohya-trainer
|
||||
### Kohya + Stable Diffusion XL
|
||||
|
||||
With this [PR](https://github.com/huggingface/diffusers/pull/4287), there should now be better support for loading Kohya-style LoRAs trained on Stable Diffusion XL (SDXL).
|
||||
After the release of [Stable Diffusion XL](https://huggingface.co/papers/2307.01952), the community contributed some amazing LoRA checkpoints trained on top of it with the Kohya trainer.
|
||||
|
||||
Here are some example checkpoints we tried out:
|
||||
|
||||
@@ -399,14 +419,33 @@ If you notice carefully, the inference UX is exactly identical to what we presen
|
||||
|
||||
Thanks to [@isidentical](https://github.com/isidentical) for helping us on integrating this feature.
|
||||
|
||||
### Known limitations specific to the Kohya-styled LoRAs
|
||||
<Tip warning={true}>
|
||||
|
||||
**Known limitations specific to the Kohya LoRAs**:
|
||||
|
||||
* When images don't looks similar to other UIs, such as ComfyUI, it can be because of multiple reasons, as explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736).
|
||||
* We don't fully support [LyCORIS checkpoints](https://github.com/KohakuBlueleaf/LyCORIS). To the best of our knowledge, our current `load_lora_weights()` should support LyCORIS checkpoints that have LoRA and LoCon modules but not the other ones, such as Hada, LoKR, etc.
|
||||
|
||||
## Stable Diffusion XL
|
||||
</Tip>
|
||||
|
||||
We support fine-tuning with [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). Please refer to the following docs:
|
||||
### TheLastBen
|
||||
|
||||
* [text_to_image/README_sdxl.md](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/README_sdxl.md)
|
||||
* [dreambooth/README_sdxl.md](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sdxl.md)
|
||||
Here is an example:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline_id = "Lykon/dreamshaper-xl-1-0"
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pipeline_id, torch_dtype=torch.float16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
lora_model_id = "TheLastBen/Papercut_SDXL"
|
||||
lora_filename = "papercut.safetensors"
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
|
||||
|
||||
prompt = "papercut sonic"
|
||||
image = pipe(prompt=prompt, num_inference_steps=20, generator=torch.manual_seed(0)).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
@@ -41,6 +41,7 @@ Unless otherwise mentioned, these are techniques that work with existing models
|
||||
13. [Model Editing](#model-editing)
|
||||
14. [DiffEdit](#diffedit)
|
||||
15. [T2I-Adapter](#t2i-adapter)
|
||||
16. [FABRIC](#fabric)
|
||||
|
||||
For convenience, we provide a table to denote which methods are inference-only and which require fine-tuning/training.
|
||||
|
||||
@@ -61,7 +62,7 @@ For convenience, we provide a table to denote which methods are inference-only a
|
||||
| [Model Editing](#model-editing) | ✅ | ❌ | |
|
||||
| [DiffEdit](#diffedit) | ✅ | ❌ | |
|
||||
| [T2I-Adapter](#t2i-adapter) | ✅ | ❌ | |
|
||||
|
||||
| [Fabric](#fabric) | ✅ | ❌ | |
|
||||
## Instruct Pix2Pix
|
||||
|
||||
[Paper](https://arxiv.org/abs/2211.09800)
|
||||
@@ -230,3 +231,14 @@ There are 8 canonical pre-trained adapters trained on different conditionings su
|
||||
depth maps, and semantic segmentations.
|
||||
|
||||
See [here](../api/pipelines/stable_diffusion/adapter) for more information on how to use it.
|
||||
|
||||
## Fabric
|
||||
|
||||
[Paper](https://arxiv.org/abs/2307.10159)
|
||||
|
||||
[Fabric](../api/pipelines/fabric) is a training-free
|
||||
approach applicable to a wide range of popular diffusion models, which exploits
|
||||
the self-attention layer present in the most widely used architectures to condition
|
||||
the diffusion process on a set of feedback images.
|
||||
|
||||
To know more details, check out the [official doc](../api/pipelines/fabric).
|
||||
|
||||
@@ -30,6 +30,7 @@ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
)
|
||||
pipeline = pipeline.to("cuda")
|
||||
```
|
||||
|
||||
@@ -12,6 +12,6 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Overview
|
||||
|
||||
A pipeline is an end-to-end class that provides a quick and easy way to use a diffusion system for inference by bundling independently trained models and schedulers together. Certain combinations of models and schedulers define specific pipeline types, like [`StableDiffusionPipeline`] or [`StableDiffusionControlNetPipeline`], with specific capabilities. All pipeline types inherit from the base [`DiffusionPipeline`] class; pass it any checkpoint, and it'll automatically detect the pipeline type and load the necessary components.
|
||||
A pipeline is an end-to-end class that provides a quick and easy way to use a diffusion system for inference by bundling independently trained models and schedulers together. Certain combinations of models and schedulers define specific pipeline types, like [`StableDiffusionXLPipeline`] or [`StableDiffusionControlNetPipeline`], with specific capabilities. All pipeline types inherit from the base [`DiffusionPipeline`] class; pass it any checkpoint, and it'll automatically detect the pipeline type and load the necessary components.
|
||||
|
||||
This section introduces you to some of the tasks supported by our pipelines such as unconditional image generation and different techniques and variations of text-to-image generation. You'll also learn how to gain more control over the generation process by setting a seed for reproducibility and weighting prompts to adjust the influence certain words in the prompt has over the output. Finally, you'll see how you can create a community pipeline for a custom task like generating images from speech.
|
||||
This section introduces you to some of the more complex pipelines like Stable Diffusion XL, ControlNet, and DiffEdit, which require additional inputs. You'll also learn how to use a distilled version of the Stable Diffusion model to speed up inference, how to control randomness on your hardware when generating images, and how to create a community pipeline for a custom task like generating images from speech.
|
||||
@@ -0,0 +1,429 @@
|
||||
# Stable Diffusion XL
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
[Stable Diffusion XL](https://huggingface.co/papers/2307.01952) (SDXL) is a powerful text-to-image generation model that iterates on the previous Stable Diffusion models in three key ways:
|
||||
|
||||
1. the UNet is 3x larger and SDXL combines a second text encoder (OpenCLIP ViT-bigG/14) with the original text encoder to significantly increase the number of parameters
|
||||
2. introduces size and crop-conditioning to preserve training data from being discarded and gain more control over how a generated image should be cropped
|
||||
3. introduces a two-stage model process; the *base* model (can also be run as a standalone model) generates an image as an input to the *refiner* model which adds additional high-quality details
|
||||
|
||||
This guide will show you how to use SDXL for text-to-image, image-to-image, and inpainting.
|
||||
|
||||
Before you begin, make sure you have the following libraries installed:
|
||||
|
||||
```py
|
||||
# uncomment to install the necessary libraries in Colab
|
||||
#!pip install diffusers transformers accelerate safetensors omegaconf invisible-watermark>=0.2.0
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
We recommend installing the [invisible-watermark](https://pypi.org/project/invisible-watermark/) library to help identify images that are generated. If the invisible-watermark library is installed, it is used by default. To disable the watermarker:
|
||||
|
||||
```py
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(..., add_watermarker=False)
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
||||
## Load model checkpoints
|
||||
|
||||
Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~StableDiffusionXLPipeline.from_pretrained`] method:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
|
||||
import torch
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
You can also use the [`~StableDiffusionXLPipeline.from_single_file`] method to load a model checkpoint stored in a single file format (`.ckpt` or `.safetensors`) from the Hub or locally:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
|
||||
import torch
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
## Text-to-image
|
||||
|
||||
For text-to-image, pass a text prompt:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = pipeline(prompt=prompt).images[0]
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png" alt="generated image of an astronaut in a jungle"/>
|
||||
</div>
|
||||
|
||||
## Image-to-image
|
||||
|
||||
For image-to-image, SDXL works especially well with image sizes between 768x768 and 1024x1024. Pass an initial image, and a text prompt to condition the image with:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForImg2Img
|
||||
from diffusers.utils import load_image
|
||||
|
||||
# use from_pipe to avoid consuming additional memory when loading a checkpoint
|
||||
pipeline = AutoPipelineForImage2Image.from_pipe(pipeline_text2image).to("cuda")
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-img2img.png"
|
||||
|
||||
init_image = load_image(url).convert("RGB")
|
||||
prompt = "a dog catching a frisbee in the jungle"
|
||||
image = pipeline(prompt, image=init_image, strength=0.8, guidance_scale=10.5).images[0]
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-img2img.png" alt="generated image of a dog catching a frisbee in a jungle"/>
|
||||
</div>
|
||||
|
||||
## Inpainting
|
||||
|
||||
For inpainting, you'll need the original image and a mask of what you want to replace in the original image. Create a prompt to describe what you want to replace the masked area with.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image
|
||||
|
||||
# use from_pipe to avoid consuming additional memory when loading a checkpoint
|
||||
pipeline = AutoPipelineForInpainting.from_pipe(pipeline_text2image).to("cuda")
|
||||
|
||||
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
|
||||
mask_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-inpaint-mask.png"
|
||||
|
||||
init_image = load_image(img_url).convert("RGB")
|
||||
mask_image = load_image(mask_url).convert("RGB")
|
||||
|
||||
prompt = "A deep sea diver floating"
|
||||
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85, guidance_scale=12.5).images[0]
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-inpaint.png" alt="generated image of a deep sea diver in a jungle"/>
|
||||
</div>
|
||||
|
||||
## Refine image quality
|
||||
|
||||
SDXL includes a [refiner model](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0) specialized in denoising low-noise stage images to generate higher-quality images from the base model. There are two ways to use the refiner:
|
||||
|
||||
1. use the base and refiner model together to produce a refined image
|
||||
2. use the base model to produce an image, and subsequently use the refiner model to add more details to the image (this is how SDXL is originally trained)
|
||||
|
||||
### Base + refiner model
|
||||
|
||||
When you use the base and refiner model together to generate an image, this is known as an ([*ensemble of expert denoisers*](https://research.nvidia.com/labs/dir/eDiff-I/)). The ensemble of expert denoisers approach requires less overall denoising steps versus passing the base model's output to the refiner model, so it should be significantly faster to run. However, you won't be able to inspect the base model's output because it still contains a large amount of noise.
|
||||
|
||||
As an ensemble of expert denoisers, the base model serves as the expert during the high-noise diffusion stage and the refiner model serves as the expert during the low-noise diffusion stage. Load the base and refiner model:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
base = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
refiner = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0",
|
||||
text_encoder_2=base.text_encoder_2,
|
||||
vae=base.vae,
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
To use this approach, you need to define the number of timesteps for each model to run through their respective stages. For the base model, this is controlled by the [`denoising_end`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline.__call__.denoising_end) parameter and for the refiner model, it is controlled by the [`denoising_start`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline.__call__.denoising_start) parameter.
|
||||
|
||||
<Tip>
|
||||
|
||||
The `denoising_end` and `denoising_start` parameters should be a float between 0 and 1. These parameters are represented as a proportion of discrete timesteps as defined by the scheduler. If you're also using the `strength` parameter, it'll be ignored because the number of denoising steps is determined by the discrete timesteps the model is trained on and the declared fractional cutoff.
|
||||
|
||||
</Tip>
|
||||
|
||||
Let's set `denoising_end=0.8` so the base model performs the first 80% of denoising the **high-noise** timesteps and set `denoising_start=0.8` so the refiner model performs the last 20% of denoising the **low-noise** timesteps. The base model output should be in **latent** space instead of a PIL image.
|
||||
|
||||
```py
|
||||
prompt = "A majestic lion jumping from a big stone at night"
|
||||
|
||||
image = base(
|
||||
prompt=prompt,
|
||||
num_inference_steps=40,
|
||||
denoising_end=0.8,
|
||||
output_type="latent",
|
||||
).images
|
||||
image = refiner(
|
||||
prompt=prompt,
|
||||
num_inference_steps=40,
|
||||
denoising_start=0.8,
|
||||
image=image,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lion_base.png" alt="generated image of a lion on a rock at night" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">base model</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lion_refined.png" alt="generated image of a lion on a rock at night in higher quality" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">ensemble of expert denoisers</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
The refiner model can also be used for inpainting in the [`StableDiffusionXLInpaintPipeline`]:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLInpaintPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
base = StableDiffusionXLInpaintPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0",
|
||||
text_encoder_2=pipe.text_encoder_2,
|
||||
vae=pipe.vae,
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
|
||||
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
|
||||
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
|
||||
init_image = load_image(img_url).convert("RGB")
|
||||
mask_image = load_image(mask_url).convert("RGB")
|
||||
|
||||
prompt = "A majestic tiger sitting on a bench"
|
||||
num_inference_steps = 75
|
||||
high_noise_frac = 0.7
|
||||
|
||||
image = base(
|
||||
prompt=prompt,
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
num_inference_steps=num_inference_steps,
|
||||
denoising_end=high_noise_frac,
|
||||
output_type="latent",
|
||||
).images
|
||||
image = refiner(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
mask_image=mask_image,
|
||||
num_inference_steps=num_inference_steps,
|
||||
denoising_start=high_noise_frac,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
This ensemble of expert denoisers method works well for all available schedulers!
|
||||
|
||||
### Base to refiner model
|
||||
|
||||
SDXL gets a boost in image quality by using the refiner model to add additional high-quality details to the fully-denoised image from the base model, in an image-to-image setting.
|
||||
|
||||
Load the base and refiner models:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
base = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
refiner = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0",
|
||||
text_encoder_2=pipe.text_encoder_2,
|
||||
vae=pipe.vae,
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
Generate an image from the base model, and set the model output to **latent** space:
|
||||
|
||||
```py
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
|
||||
image = base(prompt=prompt, output_type="latent").images[0]
|
||||
```
|
||||
|
||||
Pass the generated image to the refiner model:
|
||||
|
||||
```py
|
||||
image = refiner(prompt=prompt, image=image[None, :]).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/init_image.png" alt="generated image of an astronaut riding a green horse on Mars" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">base model</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/refined_image.png" alt="higher quality generated image of an astronaut riding a green horse on Mars" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">base model + refiner model</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
For inpainting, load the refiner model in the [`StableDiffusionXLInpaintPipeline`], remove the `denoising_end` and `denoising_start` parameters, and choose a smaller number of inference steps for the refiner.
|
||||
|
||||
## Micro-conditioning
|
||||
|
||||
SDXL training involves several additional conditioning techniques, which are referred to as *micro-conditioning*. These include original image size, target image size, and cropping parameters. The micro-conditionings can be used at inference time to create high-quality, centered images.
|
||||
|
||||
<Tip>
|
||||
|
||||
You can use both micro-conditioning and negative micro-conditioning parameters thanks to classifier-free guidance. They are available in the [`StableDiffusionXLPipeline`], [`StableDiffusionXLImg2ImgPipeline`], [`StableDiffusionXLInpaintPipeline`], and [`StableDiffusionXLControlNetPipeline`].
|
||||
|
||||
</Tip>
|
||||
|
||||
### Size conditioning
|
||||
|
||||
There are two types of size conditioning:
|
||||
|
||||
- [`original_size`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline.__call__.original_size) conditioning comes from upscaled images in the training batch (because it would be wasteful to discard the smaller images which make up almost 40% of the total training data). This way, SDXL learns that upscaling artifacts are not supposed to be present in high-resolution images. During inference, you can use `original_size` to indicate the original image resolution. Using the default value of `(1024, 1024)` produces higher-quality images that resemble the 1024x1024 images in the dataset. If you choose to use a lower resolution, such as `(256, 256)`, the model still generates 1024x1024 images, but they'll look like the low resolution images (simpler patterns, blurring) in the dataset.
|
||||
|
||||
- [`target_size`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline.__call__.target_size) conditioning comes from finetuning SDXL to support different image aspect ratios. During inference, if you use the default value of `(1024, 1024)`, you'll get an image that resembles the composition of square images in the dataset. We recommend using the same value for `target_size` and `original_size`, but feel free to experiment with other options!
|
||||
|
||||
🤗 Diffusers also lets you specify negative conditions about an image's size to steer generation away from certain image resolutions:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_original_size=(512, 512),
|
||||
negative_target_size=(1024, 1024),
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex flex-col justify-center">
|
||||
<img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/negative_conditions.png"/>
|
||||
<figcaption class="text-center">Images negative conditioned on image resolutions of (128, 128), (256, 256), and (512, 512).</figcaption>
|
||||
</div>
|
||||
|
||||
### Crop conditioning
|
||||
|
||||
Images generated by previous Stable Diffusion models may sometimes appear to be cropped. This is because images are actually cropped during training so that all the images in a batch have the same size. By conditioning on crop coordinates, SDXL *learns* that no cropping - coordinates `(0, 0)` - usually correlates with centered subjects and complete faces (this is the default value in 🤗 Diffusers). You can experiment with different coordinates if you want to generate off-centered compositions!
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = pipeline(prompt=prompt, crops_coords_top_left=(256,0)).images[0]
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-cropped.png" alt="generated image of an astronaut in a jungle, slightly cropped"/>
|
||||
</div>
|
||||
|
||||
You can also specify negative cropping coordinates to steer generation away from certain cropping parameters:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_original_size=(512, 512),
|
||||
negative_crops_coords_top_left=(0, 0),
|
||||
negative_target_size=(1024, 1024),
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Use a different prompt for each text-encoder
|
||||
|
||||
SDXL uses two text-encoders, so it is possible to pass a different prompt to each text-encoder, which can [improve quality](https://github.com/huggingface/diffusers/issues/4004#issuecomment-1627764201). Pass your original prompt to `prompt` and the second prompt to `prompt_2` (use `negative_prompt` and `negative_prompt_2` if you're using a negative prompts):
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
# prompt is passed to OAI CLIP-ViT/L-14
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
# prompt_2 is passed to OpenCLIP-ViT/bigG-14
|
||||
prompt_2 = "Van Gogh painting"
|
||||
image = pipeline(prompt=prompt, prompt_2=prompt_2).images[0]
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-double-prompt.png" alt="generated image of an astronaut in a jungle in the style of a van gogh painting"/>
|
||||
</div>
|
||||
|
||||
## Optimizations
|
||||
|
||||
SDXL is a large model, and you may need to optimize memory to get it to run on your hardware. Here are some tips to save memory and speed up inference.
|
||||
|
||||
1. Offload the model to the CPU with [`~StableDiffusionXLPipeline.enable_model_cpu_offload`] for out-of-memory errors:
|
||||
|
||||
```diff
|
||||
- base.to("cuda")
|
||||
- refiner.to("cuda")
|
||||
+ base.enable_model_cpu_offload
|
||||
+ refiner.enable_model_cpu_offload
|
||||
```
|
||||
|
||||
2. Use `torch.compile` for ~20% speed-up (you need `torch>2.0`):
|
||||
|
||||
```diff
|
||||
+ base.unet = torch.compile(base.unet, mode="reduce-overhead", fullgraph=True)
|
||||
+ refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
|
||||
```
|
||||
|
||||
3. Enable [xFormers](/optimization/xformers) to run SDXL if `torch<2.0`:
|
||||
|
||||
```diff
|
||||
+ base.enable_xformers_memory_efficient_attention()
|
||||
+ refiner.enable_xformers_memory_efficient_attention()
|
||||
```
|
||||
|
||||
## Other resources
|
||||
|
||||
If you're interested in experimenting with a minimal version of the [`UNet2DConditionModel`] used in SDXL, take a look at the [minSDXL](https://github.com/cloneofsimo/minSDXL) implementation which is written in PyTorch and directly compatible with 🤗 Diffusers.
|
||||
@@ -143,8 +143,8 @@ image
|
||||
A conjunction diffuses each prompt independently and concatenates their results by their weighted sum. Add `.and()` to the end of a list of prompts to create a conjunction:
|
||||
|
||||
```py
|
||||
prompt_embeds = compel_proc('("a red cat, playing with a, ball").and()')
|
||||
generator = torch.Generator(device="cuda").manual_seed(33)
|
||||
prompt_embeds = compel_proc('["a red cat", "playing with a", "ball"].and()')
|
||||
generator = torch.Generator(device="cuda").manual_seed(55)
|
||||
|
||||
image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
|
||||
image
|
||||
|
||||
@@ -15,8 +15,7 @@ specific language governing permissions and limitations under the License.
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242)는 한 주제에 대한 적은 이미지(3~5개)만으로도 stable diffusion과 같이 text-to-image 모델을 개인화할 수 있는 방법입니다. 이를 통해 모델은 다양한 장면, 포즈 및 장면(뷰)에서 피사체에 대해 맥락화(contextualized)된 이미지를 생성할 수 있습니다.
|
||||
|
||||

|
||||
<a href="https://dreambooth.github.io">project's blog.</a></small>
|
||||
<small><a href="https://dreambooth.github.io">프로젝트 블로그</a>에서의 Dreambooth 예시</small>
|
||||
<small>에서의 Dreambooth 예시 <a href="https://dreambooth.github.io">project's blog.</a></small>
|
||||
|
||||
|
||||
이 가이드는 다양한 GPU, Flax 사양에 대해 [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) 모델로 DreamBooth를 파인튜닝하는 방법을 보여줍니다. 더 깊이 파고들어 작동 방식을 확인하는 데 관심이 있는 경우, 이 가이드에 사용된 DreamBooth의 모든 학습 스크립트를 [여기](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)에서 찾을 수 있습니다.
|
||||
@@ -472,4 +471,4 @@ image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
|
||||
image.save("dog-bucket.png")
|
||||
```
|
||||
|
||||
[저장된 학습 체크포인트](#inference-from-a-saved-checkpoint)에서도 추론을 실행할 수도 있습니다.
|
||||
[저장된 학습 체크포인트](#inference-from-a-saved-checkpoint)에서도 추론을 실행할 수도 있습니다.
|
||||
|
||||
@@ -39,7 +39,10 @@ If a community doesn't work as expected, please open an issue and ping the autho
|
||||
| CLIP Guided Images Mixing Stable Diffusion Pipeline | Сombine images using usual diffusion models. | [CLIP Guided Images Mixing Using Stable Diffusion](#clip-guided-images-mixing-with-stable-diffusion) | - | [Karachev Denis](https://github.com/TheDenk) |
|
||||
| TensorRT Stable Diffusion Inpainting Pipeline | Accelerates the Stable Diffusion Inpainting Pipeline using TensorRT | [TensorRT Stable Diffusion Inpainting Pipeline](#tensorrt-inpainting-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon)
|
||||
| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#Zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit)
|
||||
| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#Zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit) |
|
||||
Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline support unlimited length of prompt and negative prompt, use A1111 style of prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | - | [Andrew Zhu](https://xhinker.medium.com/) |
|
||||
FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
|
||||
sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
@@ -1529,6 +1532,44 @@ CLIP guided stable diffusion images mixing pipline allows to combine two images
|
||||
This approach is using (optional) CoCa model to avoid writing image description.
|
||||
[More code examples](https://github.com/TheDenk/images_mixing)
|
||||
|
||||
|
||||
### Stable Diffusion XL Long Weighted Prompt Pipeline
|
||||
|
||||
This SDXL pipeline support unlimted length prompt and negative prompt, compatible with A1111 prompt weighted style.
|
||||
|
||||
You can provide both `prompt` and `prompt_2`. if only one prompt is provided, `prompt_2` will be a copy of the provided `prompt`. Here is a sample code to use this pipeline.
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0"
|
||||
, torch_dtype = torch.float16
|
||||
, use_safetensors = True
|
||||
, variant = "fp16"
|
||||
, custom_pipeline = "lpw_stable_diffusion_xl",
|
||||
)
|
||||
|
||||
prompt = "photo of a cute (white) cat running on the grass"*20
|
||||
prompt2 = "chasing (birds:1.5)"*20
|
||||
prompt = f"{prompt},{prompt2}"
|
||||
neg_prompt = "blur, low quality, carton, animate"
|
||||
|
||||
pipe.to("cuda")
|
||||
images = pipe(
|
||||
prompt = prompt
|
||||
, negative_prompt = neg_prompt
|
||||
).images[0]
|
||||
|
||||
pipe.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
images
|
||||
```
|
||||
|
||||
In the above code, the `prompt2` is appended to the `prompt`, which is more than 77 tokens. "birds" are showing up in the result.
|
||||

|
||||
|
||||
## Example Images Mixing (with CoCa)
|
||||
```python
|
||||
import requests
|
||||
@@ -1850,3 +1891,172 @@ for obj in range(bs):
|
||||
|
||||
```
|
||||
|
||||
### Stable Diffusion XL Reference
|
||||
|
||||
This pipeline uses the Reference . Refer to the [stable_diffusion_reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-reference).
|
||||
|
||||
|
||||
```py
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffusers.utils import load_image
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.schedulers import UniPCMultistepScheduler
|
||||
input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
|
||||
|
||||
# pipe = DiffusionPipeline.from_pretrained(
|
||||
# "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
# custom_pipeline="stable_diffusion_xl_reference",
|
||||
# torch_dtype=torch.float16,
|
||||
# use_safetensors=True,
|
||||
# variant="fp16").to('cuda:0')
|
||||
|
||||
pipe = StableDiffusionXLReferencePipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16").to('cuda:0')
|
||||
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
result_img = pipe(ref_image=input_image,
|
||||
prompt="1girl",
|
||||
num_inference_steps=20,
|
||||
reference_attn=True,
|
||||
reference_adain=True).images[0]
|
||||
```
|
||||
|
||||
Reference Image
|
||||
|
||||

|
||||
|
||||
Output Image
|
||||
|
||||
`prompt: 1 girl`
|
||||
|
||||
`reference_attn=True, reference_adain=True, num_inference_steps=20`
|
||||

|
||||
|
||||
Reference Image
|
||||

|
||||
|
||||
|
||||
Output Image
|
||||
|
||||
`prompt: A dog`
|
||||
|
||||
`reference_attn=True, reference_adain=False, num_inference_steps=20`
|
||||

|
||||
|
||||
Reference Image
|
||||

|
||||
|
||||
Output Image
|
||||
|
||||
`prompt: An astronaut riding a lion`
|
||||
|
||||
`reference_attn=True, reference_adain=True, num_inference_steps=20`
|
||||

|
||||
|
||||
### Stable diffusion fabric pipeline
|
||||
|
||||
FABRIC approach applicable to a wide range of popular diffusion models, which exploits
|
||||
the self-attention layer present in the most widely used architectures to condition
|
||||
the diffusion process on a set of feedback images.
|
||||
|
||||
|
||||
```python
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
from diffusers import Diffusionpipeline
|
||||
|
||||
# load the pipeline
|
||||
# make sure you're logged in with `huggingface-cli login`
|
||||
model_id_or_path = "runwayml/stable-diffusion-v1-5"
|
||||
#can also be used with dreamlike-art/dreamlike-photoreal-2.0
|
||||
pipe = DiffusionPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16, custom_pipeline="pipeline_fabric").to("cuda")
|
||||
|
||||
# let's specify a prompt
|
||||
prompt = "An astronaut riding an elephant"
|
||||
negative_prompt = "lowres, cropped"
|
||||
|
||||
# call the pipeline
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=20,
|
||||
generator=torch.manual_seed(12)
|
||||
).images[0]
|
||||
|
||||
image.save("horse_to_elephant.jpg")
|
||||
|
||||
# let's try another example with feedback
|
||||
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png"
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
prompt = "photo, A blue colored car, fish eye"
|
||||
liked = [init_image]
|
||||
## same goes with disliked
|
||||
|
||||
# call the pipeline
|
||||
torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
liked = liked,
|
||||
num_inference_steps=20,
|
||||
).images[0]
|
||||
|
||||
image.save("black_to_blue.png")
|
||||
```
|
||||
|
||||
*With enough feedbacks you can create very similar high quality images.*
|
||||
|
||||
The original codebase can be found at [sd-fabric/fabric](https://github.com/sd-fabric/fabric), and available checkpoints are [dreamlike-art/dreamlike-photoreal-2.0](https://huggingface.co/dreamlike-art/dreamlike-photoreal-2.0), [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5), and [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) (may give unexpected results).
|
||||
|
||||
Let's have a look at the images (*512X512*)
|
||||
|
||||
| Without Feedback | With Feedback (1st image) |
|
||||
|---------------------|---------------------|
|
||||
|  |  |
|
||||
|
||||
|
||||
### Masked Im2Im Stable Diffusion Pipeline
|
||||
|
||||
This pipeline reimplements sketch inpaint feature from A1111 for non-inpaint models. The following code reads two images, original and one with mask painted over it. It computes mask as a difference of two images and does the inpainting in the area defined by the mask.
|
||||
|
||||
```python
|
||||
img = PIL.Image.open("./mech.png")
|
||||
# read image with mask painted over
|
||||
img_paint = PIL.Image.open("./mech_painted.png")
|
||||
neq = numpy.any(numpy.array(img) != numpy.array(img_paint), axis=-1)
|
||||
mask = neq / neq.max()
|
||||
|
||||
pipeline = MaskedStableDiffusionImg2ImgPipeline.from_pretrained("frankjoshua/icbinpICantBelieveIts_v8")
|
||||
|
||||
# works best with EulerAncestralDiscreteScheduler
|
||||
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
generator = torch.Generator(device="cpu").manual_seed(4)
|
||||
|
||||
prompt = "a man wearing a mask"
|
||||
result = pipeline(prompt=prompt, image=img_paint, mask=mask, strength=0.75,
|
||||
generator=generator)
|
||||
result.images[0].save("result.png")
|
||||
```
|
||||
|
||||
original image mech.png
|
||||
|
||||
<img src=https://github.com/noskill/diffusers/assets/733626/10ad972d-d655-43cb-8de1-039e3d79e849 width="25%" >
|
||||
|
||||
image with mask mech_painted.png
|
||||
|
||||
<img src=https://github.com/noskill/diffusers/assets/733626/c334466a-67fe-4377-9ff7-f46021b9c224 width="25%" >
|
||||
|
||||
result:
|
||||
|
||||
<img src=https://github.com/noskill/diffusers/assets/733626/23a0a71d-51db-471e-926a-107ac62512a8 width="25%" >
|
||||
|
||||
|
||||
@@ -408,7 +408,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
|
||||
with self.progress_bar(total=num_inference_steps):
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
@@ -440,6 +440,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
progress_bar.update()
|
||||
# Hardcode 0.18215 because stable-diffusion-2-base has not self.vae.config.scaling_factor
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,261 @@
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from diffusers import StableDiffusionImg2ImgPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
|
||||
|
||||
class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
debug_save = False
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
strength: float = 0.8,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
mask: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image` or tensor representing an image batch to be used as the starting point. Can also accept image
|
||||
latents as `image`, but if passing latents directly it is not encoded again.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
||||
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
||||
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
||||
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
||||
essentially ignores `image`.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference. This parameter is modulated by `strength`.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
A higher guidance scale value encourages the model to generate images closely linked to the text
|
||||
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
||||
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
||||
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
mask (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`, *optional*):
|
||||
A mask with non-zero elements for the area to be inpainted. If not specified, no mask is applied.
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||
"not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
# code adapted from parent class StableDiffusionImg2ImgPipeline
|
||||
|
||||
# 0. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
|
||||
|
||||
# 1. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 2. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
|
||||
# 3. Preprocess image
|
||||
image = self.image_processor.preprocess(image)
|
||||
|
||||
# 4. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
# it is sampled from the latent distribution of the VAE
|
||||
latents = self.prepare_latents(
|
||||
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
|
||||
)
|
||||
|
||||
# mean of the latent distribution
|
||||
init_latents = [
|
||||
self.vae.encode(image.to(device=device, dtype=prompt_embeds.dtype)[i : i + 1]).latent_dist.mean
|
||||
for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
|
||||
# 6. create latent mask
|
||||
latent_mask = self._make_latent_mask(latents, mask)
|
||||
|
||||
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if latent_mask is not None:
|
||||
latents = torch.lerp(init_latents * self.vae.config.scaling_factor, latents, latent_mask)
|
||||
noise_pred = torch.lerp(torch.zeros_like(noise_pred), noise_pred, latent_mask)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
scaled = latents / self.vae.config.scaling_factor
|
||||
if latent_mask is not None:
|
||||
# scaled = latents / self.vae.config.scaling_factor * latent_mask + init_latents * (1 - latent_mask)
|
||||
scaled = torch.lerp(init_latents, scaled, latent_mask)
|
||||
image = self.vae.decode(scaled, return_dict=False)[0]
|
||||
if self.debug_save:
|
||||
image_gen = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image_gen = self.image_processor.postprocess(image_gen, output_type=output_type, do_denormalize=[True])
|
||||
image_gen[0].save("from_latent.png")
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
has_nsfw_concept = None
|
||||
|
||||
if has_nsfw_concept is None:
|
||||
do_denormalize = [True] * image.shape[0]
|
||||
else:
|
||||
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
|
||||
def _make_latent_mask(self, latents, mask):
|
||||
if mask is not None:
|
||||
latent_mask = []
|
||||
if not isinstance(mask, list):
|
||||
tmp_mask = [mask]
|
||||
else:
|
||||
tmp_mask = mask
|
||||
_, l_channels, l_height, l_width = latents.shape
|
||||
for m in tmp_mask:
|
||||
if not isinstance(m, PIL.Image.Image):
|
||||
if len(m.shape) == 2:
|
||||
m = m[..., np.newaxis]
|
||||
if m.max() > 1:
|
||||
m = m / 255.0
|
||||
m = self.image_processor.numpy_to_pil(m)[0]
|
||||
if m.mode != "L":
|
||||
m = m.convert("L")
|
||||
resized = self.image_processor.resize(m, l_height, l_width)
|
||||
if self.debug_save:
|
||||
resized.save("latent_mask.png")
|
||||
latent_mask.append(np.repeat(np.array(resized)[np.newaxis, :, :], l_channels, axis=0))
|
||||
latent_mask = torch.as_tensor(np.stack(latent_mask)).to(latents)
|
||||
latent_mask = latent_mask / latent_mask.max()
|
||||
return latent_mask
|
||||
@@ -0,0 +1,751 @@
|
||||
# Copyright 2023 FABRIC authors and the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.schedulers import EulerAncestralDiscreteScheduler, KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
deprecate,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers import DiffusionPipeline
|
||||
>>> import torch
|
||||
|
||||
>>> model_id = "dreamlike-art/dreamlike-photoreal-2.0"
|
||||
>>> pipe = DiffusionPipeline(model_id, torch_dtype=torch.float16, custom_pipeline="pipeline_fabric")
|
||||
>>> pipe = pipe.to("cuda")
|
||||
>>> prompt = "a giant standing in a fantasy landscape best quality"
|
||||
>>> liked = [] # list of images for positive feedback
|
||||
>>> disliked = [] # list of images for negative feedback
|
||||
>>> image = pipe(prompt, num_images=4, liked=liked, disliked=disliked).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class FabricCrossAttnProcessor:
|
||||
def __init__(self):
|
||||
self.attntion_probs = None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
weights=None,
|
||||
lora_scale=1.0,
|
||||
):
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if isinstance(attn.processor, LoRAAttnProcessor):
|
||||
query = attn.to_q(hidden_states) + lora_scale * attn.processor.to_q_lora(hidden_states)
|
||||
else:
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
if isinstance(attn.processor, LoRAAttnProcessor):
|
||||
key = attn.to_k(encoder_hidden_states) + lora_scale * attn.processor.to_k_lora(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states) + lora_scale * attn.processor.to_v_lora(encoder_hidden_states)
|
||||
else:
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
|
||||
if weights is not None:
|
||||
if weights.shape[0] != 1:
|
||||
weights = weights.repeat_interleave(attn.heads, dim=0)
|
||||
attention_probs = attention_probs * weights[:, None]
|
||||
attention_probs = attention_probs / attention_probs.sum(dim=-1, keepdim=True)
|
||||
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
if isinstance(attn.processor, LoRAAttnProcessor):
|
||||
hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.processor.to_out_lora(hidden_states)
|
||||
else:
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FabricPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion and conditioning the results using feedback images.
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`~transformers.CLIPTextModel`]):
|
||||
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
tokenizer ([`~transformers.CLIPTokenizer`]):
|
||||
A `CLIPTokenizer` to tokenize text.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
A `UNet2DConditionModel` to denoise the encoded image latents.
|
||||
scheduler ([`EulerAncestralDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
unet=unet,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
"""
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if self.text_encoder is not None:
|
||||
prompt_embeds_dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
prompt_embeds_dtype = self.unet.dtype
|
||||
else:
|
||||
prompt_embeds_dtype = prompt_embeds.dtype
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def get_unet_hidden_states(self, z_all, t, prompt_embd):
|
||||
cached_hidden_states = []
|
||||
for module in self.unet.modules():
|
||||
if isinstance(module, BasicTransformerBlock):
|
||||
|
||||
def new_forward(self, hidden_states, *args, **kwargs):
|
||||
cached_hidden_states.append(hidden_states.clone().detach().cpu())
|
||||
return self.old_forward(hidden_states, *args, **kwargs)
|
||||
|
||||
module.attn1.old_forward = module.attn1.forward
|
||||
module.attn1.forward = new_forward.__get__(module.attn1)
|
||||
|
||||
# run forward pass to cache hidden states, output can be discarded
|
||||
_ = self.unet(z_all, t, encoder_hidden_states=prompt_embd)
|
||||
|
||||
# restore original forward pass
|
||||
for module in self.unet.modules():
|
||||
if isinstance(module, BasicTransformerBlock):
|
||||
module.attn1.forward = module.attn1.old_forward
|
||||
del module.attn1.old_forward
|
||||
|
||||
return cached_hidden_states
|
||||
|
||||
def unet_forward_with_cached_hidden_states(
|
||||
self,
|
||||
z_all,
|
||||
t,
|
||||
prompt_embd,
|
||||
cached_pos_hiddens: Optional[List[torch.Tensor]] = None,
|
||||
cached_neg_hiddens: Optional[List[torch.Tensor]] = None,
|
||||
pos_weights=(0.8, 0.8),
|
||||
neg_weights=(0.5, 0.5),
|
||||
):
|
||||
if cached_pos_hiddens is None and cached_neg_hiddens is None:
|
||||
return self.unet(z_all, t, encoder_hidden_states=prompt_embd)
|
||||
|
||||
local_pos_weights = torch.linspace(*pos_weights, steps=len(self.unet.down_blocks) + 1)[:-1].tolist()
|
||||
local_neg_weights = torch.linspace(*neg_weights, steps=len(self.unet.down_blocks) + 1)[:-1].tolist()
|
||||
for block, pos_weight, neg_weight in zip(
|
||||
self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks,
|
||||
local_pos_weights + [pos_weights[1]] + local_pos_weights[::-1],
|
||||
local_neg_weights + [neg_weights[1]] + local_neg_weights[::-1],
|
||||
):
|
||||
for module in block.modules():
|
||||
if isinstance(module, BasicTransformerBlock):
|
||||
|
||||
def new_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
pos_weight=pos_weight,
|
||||
neg_weight=neg_weight,
|
||||
**kwargs,
|
||||
):
|
||||
cond_hiddens, uncond_hiddens = hidden_states.chunk(2, dim=0)
|
||||
batch_size, d_model = cond_hiddens.shape[:2]
|
||||
device, dtype = hidden_states.device, hidden_states.dtype
|
||||
|
||||
weights = torch.ones(batch_size, d_model, device=device, dtype=dtype)
|
||||
out_pos = self.old_forward(hidden_states)
|
||||
out_neg = self.old_forward(hidden_states)
|
||||
|
||||
if cached_pos_hiddens is not None:
|
||||
cached_pos_hs = cached_pos_hiddens.pop(0).to(hidden_states.device)
|
||||
cond_pos_hs = torch.cat([cond_hiddens, cached_pos_hs], dim=1)
|
||||
pos_weights = weights.clone().repeat(1, 1 + cached_pos_hs.shape[1] // d_model)
|
||||
pos_weights[:, d_model:] = pos_weight
|
||||
attn_with_weights = FabricCrossAttnProcessor()
|
||||
out_pos = attn_with_weights(
|
||||
self,
|
||||
cond_hiddens,
|
||||
encoder_hidden_states=cond_pos_hs,
|
||||
weights=pos_weights,
|
||||
)
|
||||
else:
|
||||
out_pos = self.old_forward(cond_hiddens)
|
||||
|
||||
if cached_neg_hiddens is not None:
|
||||
cached_neg_hs = cached_neg_hiddens.pop(0).to(hidden_states.device)
|
||||
uncond_neg_hs = torch.cat([uncond_hiddens, cached_neg_hs], dim=1)
|
||||
neg_weights = weights.clone().repeat(1, 1 + cached_neg_hs.shape[1] // d_model)
|
||||
neg_weights[:, d_model:] = neg_weight
|
||||
attn_with_weights = FabricCrossAttnProcessor()
|
||||
out_neg = attn_with_weights(
|
||||
self,
|
||||
uncond_hiddens,
|
||||
encoder_hidden_states=uncond_neg_hs,
|
||||
weights=neg_weights,
|
||||
)
|
||||
else:
|
||||
out_neg = self.old_forward(uncond_hiddens)
|
||||
|
||||
out = torch.cat([out_pos, out_neg], dim=0)
|
||||
return out
|
||||
|
||||
module.attn1.old_forward = module.attn1.forward
|
||||
module.attn1.forward = new_forward.__get__(module.attn1)
|
||||
|
||||
out = self.unet(z_all, t, encoder_hidden_states=prompt_embd)
|
||||
|
||||
# restore original forward pass
|
||||
for module in self.unet.modules():
|
||||
if isinstance(module, BasicTransformerBlock):
|
||||
module.attn1.forward = module.attn1.old_forward
|
||||
del module.attn1.old_forward
|
||||
|
||||
return out
|
||||
|
||||
def preprocess_feedback_images(self, images, vae, dim, device, dtype, generator) -> torch.tensor:
|
||||
images_t = [self.image_to_tensor(img, dim, dtype) for img in images]
|
||||
images_t = torch.stack(images_t).to(device)
|
||||
latents = vae.config.scaling_factor * vae.encode(images_t).latent_dist.sample(generator)
|
||||
|
||||
return torch.cat([latents], dim=0)
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt=None,
|
||||
liked=None,
|
||||
disliked=None,
|
||||
height=None,
|
||||
width=None,
|
||||
):
|
||||
if prompt is None:
|
||||
raise ValueError("Provide `prompt`. Cannot leave both `prompt` undefined.")
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and (
|
||||
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
if liked is not None and not isinstance(liked, list):
|
||||
raise ValueError(f"`liked` has to be of type `list` but is {type(liked)}")
|
||||
|
||||
if disliked is not None and not isinstance(disliked, list):
|
||||
raise ValueError(f"`disliked` has to be of type `list` but is {type(disliked)}")
|
||||
|
||||
if height is not None and not isinstance(height, int):
|
||||
raise ValueError(f"`height` has to be of type `int` but is {type(height)}")
|
||||
|
||||
if width is not None and not isinstance(width, int):
|
||||
raise ValueError(f"`width` has to be of type `int` but is {type(width)}")
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = "",
|
||||
negative_prompt: Optional[Union[str, List[str]]] = "lowres, bad anatomy, bad hands, cropped, worst quality",
|
||||
liked: Optional[Union[List[str], List[Image.Image]]] = [],
|
||||
disliked: Optional[Union[List[str], List[Image.Image]]] = [],
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
return_dict: bool = True,
|
||||
num_images: int = 4,
|
||||
guidance_scale: float = 7.0,
|
||||
num_inference_steps: int = 20,
|
||||
output_type: Optional[str] = "pil",
|
||||
feedback_start_ratio: float = 0.33,
|
||||
feedback_end_ratio: float = 0.66,
|
||||
min_weight: float = 0.05,
|
||||
max_weight: float = 0.8,
|
||||
neg_scale: float = 0.5,
|
||||
pos_bottleneck_scale: float = 1.0,
|
||||
neg_bottleneck_scale: float = 1.0,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation. Generate a trajectory of images with binary feedback. The
|
||||
feedback can be given as a list of liked and disliked images.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
||||
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
||||
liked (`List[Image.Image]` or `List[str]`, *optional*):
|
||||
Encourages images with liked features.
|
||||
disliked (`List[Image.Image]` or `List[str]`, *optional*):
|
||||
Discourages images with disliked features.
|
||||
generator (`torch.Generator` or `List[torch.Generator]` or `int`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) or an `int` to
|
||||
make generation deterministic.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
Height of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
Width of the generated image.
|
||||
num_images (`int`, *optional*, defaults to 4):
|
||||
The number of images to generate per prompt.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
A higher guidance scale value encourages the model to generate images closely linked to the text
|
||||
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
num_inference_steps (`int`, *optional*, defaults to 20):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
feedback_start_ratio (`float`, *optional*, defaults to `.33`):
|
||||
Start point for providing feedback (between 0 and 1).
|
||||
feedback_end_ratio (`float`, *optional*, defaults to `.66`):
|
||||
End point for providing feedback (between 0 and 1).
|
||||
min_weight (`float`, *optional*, defaults to `.05`):
|
||||
Minimum weight for feedback.
|
||||
max_weight (`float`, *optional*, defults tp `1.0`):
|
||||
Maximum weight for feedback.
|
||||
neg_scale (`float`, *optional*, defaults to `.5`):
|
||||
Scale factor for negative feedback.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.fabric.FabricPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||
"not-safe-for-work" (nsfw) content.
|
||||
|
||||
"""
|
||||
|
||||
self.check_inputs(prompt, negative_prompt, liked, disliked)
|
||||
|
||||
device = self._execution_device
|
||||
dtype = self.unet.dtype
|
||||
|
||||
if isinstance(prompt, str) and prompt is not None:
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list) and prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if isinstance(negative_prompt, str):
|
||||
negative_prompt = negative_prompt
|
||||
elif isinstance(negative_prompt, list):
|
||||
negative_prompt = negative_prompt
|
||||
else:
|
||||
assert len(negative_prompt) == batch_size
|
||||
|
||||
shape = (
|
||||
batch_size * num_images,
|
||||
self.unet.config.in_channels,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
latent_noise = randn_tensor(
|
||||
shape,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
positive_latents = (
|
||||
self.preprocess_feedback_images(liked, self.vae, (height, width), device, dtype, generator)
|
||||
if liked and len(liked) > 0
|
||||
else torch.tensor(
|
||||
[],
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
negative_latents = (
|
||||
self.preprocess_feedback_images(disliked, self.vae, (height, width), device, dtype, generator)
|
||||
if disliked and len(disliked) > 0
|
||||
else torch.tensor(
|
||||
[],
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
do_classifier_free_guidance = guidance_scale > 0.1
|
||||
|
||||
(prompt_neg_embs, prompt_pos_embs) = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
).split([num_images * batch_size, num_images * batch_size])
|
||||
|
||||
batched_prompt_embd = torch.cat([prompt_pos_embs, prompt_neg_embs], dim=0)
|
||||
|
||||
null_tokens = self.tokenizer(
|
||||
[""],
|
||||
return_tensors="pt",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = null_tokens.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
null_prompt_emb = self.text_encoder(
|
||||
input_ids=null_tokens.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
).last_hidden_state
|
||||
|
||||
null_prompt_emb = null_prompt_emb.to(device=device, dtype=dtype)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
latent_noise = latent_noise * self.scheduler.init_noise_sigma
|
||||
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
ref_start_idx = round(len(timesteps) * feedback_start_ratio)
|
||||
ref_end_idx = round(len(timesteps) * feedback_end_ratio)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as pbar:
|
||||
for i, t in enumerate(timesteps):
|
||||
sigma = self.scheduler.sigma_t[t] if hasattr(self.scheduler, "sigma_t") else 0
|
||||
if hasattr(self.scheduler, "sigmas"):
|
||||
sigma = self.scheduler.sigmas[i]
|
||||
|
||||
alpha_hat = 1 / (sigma**2 + 1)
|
||||
|
||||
z_single = self.scheduler.scale_model_input(latent_noise, t)
|
||||
z_all = torch.cat([z_single] * 2, dim=0)
|
||||
z_ref = torch.cat([positive_latents, negative_latents], dim=0)
|
||||
|
||||
if i >= ref_start_idx and i <= ref_end_idx:
|
||||
weight_factor = max_weight
|
||||
else:
|
||||
weight_factor = min_weight
|
||||
|
||||
pos_ws = (weight_factor, weight_factor * pos_bottleneck_scale)
|
||||
neg_ws = (weight_factor * neg_scale, weight_factor * neg_scale * neg_bottleneck_scale)
|
||||
|
||||
if z_ref.size(0) > 0 and weight_factor > 0:
|
||||
noise = torch.randn_like(z_ref)
|
||||
if isinstance(self.scheduler, EulerAncestralDiscreteScheduler):
|
||||
z_ref_noised = (alpha_hat**0.5 * z_ref + (1 - alpha_hat) ** 0.5 * noise).type(dtype)
|
||||
else:
|
||||
z_ref_noised = self.scheduler.add_noise(z_ref, noise, t)
|
||||
|
||||
ref_prompt_embd = torch.cat(
|
||||
[null_prompt_emb] * (len(positive_latents) + len(negative_latents)), dim=0
|
||||
)
|
||||
cached_hidden_states = self.get_unet_hidden_states(z_ref_noised, t, ref_prompt_embd)
|
||||
|
||||
n_pos, n_neg = positive_latents.shape[0], negative_latents.shape[0]
|
||||
cached_pos_hs, cached_neg_hs = [], []
|
||||
for hs in cached_hidden_states:
|
||||
cached_pos, cached_neg = hs.split([n_pos, n_neg], dim=0)
|
||||
cached_pos = cached_pos.view(1, -1, *cached_pos.shape[2:]).expand(num_images, -1, -1)
|
||||
cached_neg = cached_neg.view(1, -1, *cached_neg.shape[2:]).expand(num_images, -1, -1)
|
||||
cached_pos_hs.append(cached_pos)
|
||||
cached_neg_hs.append(cached_neg)
|
||||
|
||||
if n_pos == 0:
|
||||
cached_pos_hs = None
|
||||
if n_neg == 0:
|
||||
cached_neg_hs = None
|
||||
else:
|
||||
cached_pos_hs, cached_neg_hs = None, None
|
||||
unet_out = self.unet_forward_with_cached_hidden_states(
|
||||
z_all,
|
||||
t,
|
||||
prompt_embd=batched_prompt_embd,
|
||||
cached_pos_hiddens=cached_pos_hs,
|
||||
cached_neg_hiddens=cached_neg_hs,
|
||||
pos_weights=pos_ws,
|
||||
neg_weights=neg_ws,
|
||||
)[0]
|
||||
|
||||
noise_cond, noise_uncond = unet_out.chunk(2)
|
||||
guidance = noise_cond - noise_uncond
|
||||
noise_pred = noise_uncond + guidance_scale * guidance
|
||||
latent_noise = self.scheduler.step(noise_pred, t, latent_noise)[0]
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
pbar.update()
|
||||
|
||||
y = self.vae.decode(latent_noise / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
imgs = self.image_processor.postprocess(
|
||||
y,
|
||||
output_type=output_type,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return imgs
|
||||
|
||||
return StableDiffusionPipelineOutput(imgs, False)
|
||||
|
||||
def image_to_tensor(self, image: Union[str, Image.Image], dim: tuple, dtype):
|
||||
"""
|
||||
Convert latent PIL image to a torch tensor for further processing.
|
||||
"""
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image)
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
image = self.image_processor.preprocess(image, height=dim[0], width=dim[1])[0]
|
||||
return image.type(dtype)
|
||||
@@ -0,0 +1,805 @@
|
||||
# Based on stable_diffusion_reference.py
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
CrossAttnUpBlock2D,
|
||||
DownBlock2D,
|
||||
UpBlock2D,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
||||
from diffusers.utils import PIL_INTERPOLATION, logging, randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import UniPCMultistepScheduler
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
|
||||
|
||||
>>> pipe = StableDiffusionXLReferencePipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16").to('cuda:0')
|
||||
|
||||
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
>>> result_img = pipe(ref_image=input_image,
|
||||
prompt="1girl",
|
||||
num_inference_steps=20,
|
||||
reference_attn=True,
|
||||
reference_adain=True).images[0]
|
||||
|
||||
>>> result_img.show()
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def torch_dfs(model: torch.nn.Module):
|
||||
result = [model]
|
||||
for child in model.children():
|
||||
result += torch_dfs(child)
|
||||
return result
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
||||
"""
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
# rescale the results from guidance (fixes overexposure)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
|
||||
|
||||
class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
def _default_height_width(self, height, width, image):
|
||||
# NOTE: It is possible that a list of images have different
|
||||
# dimensions for each image, so just checking the first image
|
||||
# is not _exactly_ correct, but it is simple.
|
||||
while isinstance(image, list):
|
||||
image = image[0]
|
||||
|
||||
if height is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
height = image.height
|
||||
elif isinstance(image, torch.Tensor):
|
||||
height = image.shape[2]
|
||||
|
||||
height = (height // 8) * 8 # round down to nearest multiple of 8
|
||||
|
||||
if width is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
width = image.width
|
||||
elif isinstance(image, torch.Tensor):
|
||||
width = image.shape[3]
|
||||
|
||||
width = (width // 8) * 8
|
||||
|
||||
return height, width
|
||||
|
||||
def prepare_image(
|
||||
self,
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
do_classifier_free_guidance=False,
|
||||
guess_mode=False,
|
||||
):
|
||||
if not isinstance(image, torch.Tensor):
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
images = []
|
||||
|
||||
for image_ in image:
|
||||
image_ = image_.convert("RGB")
|
||||
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||
image_ = np.array(image_)
|
||||
image_ = image_[None, :]
|
||||
images.append(image_)
|
||||
|
||||
image = images
|
||||
|
||||
image = np.concatenate(image, axis=0)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = (image - 0.5) / 0.5
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
elif isinstance(image[0], torch.Tensor):
|
||||
image = torch.stack(image, dim=0)
|
||||
|
||||
image_batch_size = image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
repeat_by = num_images_per_prompt
|
||||
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if do_classifier_free_guidance and not guess_mode:
|
||||
image = torch.cat([image] * 2)
|
||||
|
||||
return image
|
||||
|
||||
def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
|
||||
refimage = refimage.to(device=device)
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
if refimage.dtype != self.vae.dtype:
|
||||
refimage = refimage.to(dtype=self.vae.dtype)
|
||||
# encode the mask image into latents space so we can concatenate it to the latents
|
||||
if isinstance(generator, list):
|
||||
ref_image_latents = [
|
||||
self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
ref_image_latents = torch.cat(ref_image_latents, dim=0)
|
||||
else:
|
||||
ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
|
||||
ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
|
||||
|
||||
# duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method
|
||||
if ref_image_latents.shape[0] < batch_size:
|
||||
if not batch_size % ref_image_latents.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||
f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
|
||||
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
|
||||
|
||||
ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
|
||||
|
||||
# aligning device to prevent device errors when concating it with the latent model input
|
||||
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
|
||||
return ref_image_latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
denoising_end: Optional[float] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
attention_auto_machine_weight: float = 1.0,
|
||||
gn_auto_machine_weight: float = 1.0,
|
||||
style_fidelity: float = 0.5,
|
||||
reference_attn: bool = True,
|
||||
reference_adain: bool = True,
|
||||
):
|
||||
assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
|
||||
|
||||
# 0. Default height and width to unet
|
||||
# height, width = self._default_height_width(height, width, ref_image)
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
negative_prompt_2,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
# 4. Preprocess reference image
|
||||
ref_image = self.prepare_image(
|
||||
image=ref_image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
# 7. Prepare reference latent variables
|
||||
ref_image_latents = self.prepare_ref_latents(
|
||||
ref_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 9. Modify self attebtion and group norm
|
||||
MODE = "write"
|
||||
uc_mask = (
|
||||
torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
|
||||
.type_as(ref_image_latents)
|
||||
.bool()
|
||||
)
|
||||
|
||||
def hacked_basic_transformer_inner_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
if self.use_ada_layer_norm:
|
||||
norm_hidden_states = self.norm1(hidden_states, timestep)
|
||||
elif self.use_ada_layer_norm_zero:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
||||
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
# 1. Self-Attention
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
if self.only_cross_attention:
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
if MODE == "write":
|
||||
self.bank.append(norm_hidden_states.detach().clone())
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
if MODE == "read":
|
||||
if attention_auto_machine_weight > self.attn_weight:
|
||||
attn_output_uc = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
|
||||
# attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
attn_output_c = attn_output_uc.clone()
|
||||
if do_classifier_free_guidance and style_fidelity > 0:
|
||||
attn_output_c[uc_mask] = self.attn1(
|
||||
norm_hidden_states[uc_mask],
|
||||
encoder_hidden_states=norm_hidden_states[uc_mask],
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc
|
||||
self.bank.clear()
|
||||
else:
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
if self.use_ada_layer_norm_zero:
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
if self.attn2 is not None:
|
||||
norm_hidden_states = (
|
||||
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
||||
)
|
||||
|
||||
# 2. Cross-Attention
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
def hacked_mid_forward(self, *args, **kwargs):
|
||||
eps = 1e-6
|
||||
x = self.original_forward(*args, **kwargs)
|
||||
if MODE == "write":
|
||||
if gn_auto_machine_weight >= self.gn_weight:
|
||||
var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
|
||||
self.mean_bank.append(mean)
|
||||
self.var_bank.append(var)
|
||||
if MODE == "read":
|
||||
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
|
||||
var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
|
||||
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
|
||||
mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
|
||||
var_acc = sum(self.var_bank) / float(len(self.var_bank))
|
||||
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
|
||||
x_uc = (((x - mean) / std) * std_acc) + mean_acc
|
||||
x_c = x_uc.clone()
|
||||
if do_classifier_free_guidance and style_fidelity > 0:
|
||||
x_c[uc_mask] = x[uc_mask]
|
||||
x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
|
||||
self.mean_bank = []
|
||||
self.var_bank = []
|
||||
return x
|
||||
|
||||
def hack_CrossAttnDownBlock2D_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
eps = 1e-6
|
||||
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
output_states = ()
|
||||
|
||||
for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
if MODE == "write":
|
||||
if gn_auto_machine_weight >= self.gn_weight:
|
||||
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
|
||||
self.mean_bank.append([mean])
|
||||
self.var_bank.append([var])
|
||||
if MODE == "read":
|
||||
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
|
||||
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
|
||||
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
|
||||
mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
|
||||
var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
|
||||
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
|
||||
hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
|
||||
hidden_states_c = hidden_states_uc.clone()
|
||||
if do_classifier_free_guidance and style_fidelity > 0:
|
||||
hidden_states_c[uc_mask] = hidden_states[uc_mask]
|
||||
hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if MODE == "read":
|
||||
self.mean_bank = []
|
||||
self.var_bank = []
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
|
||||
eps = 1e-6
|
||||
|
||||
output_states = ()
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if MODE == "write":
|
||||
if gn_auto_machine_weight >= self.gn_weight:
|
||||
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
|
||||
self.mean_bank.append([mean])
|
||||
self.var_bank.append([var])
|
||||
if MODE == "read":
|
||||
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
|
||||
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
|
||||
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
|
||||
mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
|
||||
var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
|
||||
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
|
||||
hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
|
||||
hidden_states_c = hidden_states_uc.clone()
|
||||
if do_classifier_free_guidance and style_fidelity > 0:
|
||||
hidden_states_c[uc_mask] = hidden_states[uc_mask]
|
||||
hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if MODE == "read":
|
||||
self.mean_bank = []
|
||||
self.var_bank = []
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
def hacked_CrossAttnUpBlock2D_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
eps = 1e-6
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if MODE == "write":
|
||||
if gn_auto_machine_weight >= self.gn_weight:
|
||||
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
|
||||
self.mean_bank.append([mean])
|
||||
self.var_bank.append([var])
|
||||
if MODE == "read":
|
||||
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
|
||||
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
|
||||
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
|
||||
mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
|
||||
var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
|
||||
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
|
||||
hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
|
||||
hidden_states_c = hidden_states_uc.clone()
|
||||
if do_classifier_free_guidance and style_fidelity > 0:
|
||||
hidden_states_c[uc_mask] = hidden_states[uc_mask]
|
||||
hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
|
||||
|
||||
if MODE == "read":
|
||||
self.mean_bank = []
|
||||
self.var_bank = []
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
||||
eps = 1e-6
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if MODE == "write":
|
||||
if gn_auto_machine_weight >= self.gn_weight:
|
||||
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
|
||||
self.mean_bank.append([mean])
|
||||
self.var_bank.append([var])
|
||||
if MODE == "read":
|
||||
if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
|
||||
var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
|
||||
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
|
||||
mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
|
||||
var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
|
||||
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
|
||||
hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
|
||||
hidden_states_c = hidden_states_uc.clone()
|
||||
if do_classifier_free_guidance and style_fidelity > 0:
|
||||
hidden_states_c[uc_mask] = hidden_states[uc_mask]
|
||||
hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
|
||||
|
||||
if MODE == "read":
|
||||
self.mean_bank = []
|
||||
self.var_bank = []
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
if reference_attn:
|
||||
attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
|
||||
attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
|
||||
|
||||
for i, module in enumerate(attn_modules):
|
||||
module._original_inner_forward = module.forward
|
||||
module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
|
||||
module.bank = []
|
||||
module.attn_weight = float(i) / float(len(attn_modules))
|
||||
|
||||
if reference_adain:
|
||||
gn_modules = [self.unet.mid_block]
|
||||
self.unet.mid_block.gn_weight = 0
|
||||
|
||||
down_blocks = self.unet.down_blocks
|
||||
for w, module in enumerate(down_blocks):
|
||||
module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
|
||||
gn_modules.append(module)
|
||||
|
||||
up_blocks = self.unet.up_blocks
|
||||
for w, module in enumerate(up_blocks):
|
||||
module.gn_weight = float(w) / float(len(up_blocks))
|
||||
gn_modules.append(module)
|
||||
|
||||
for i, module in enumerate(gn_modules):
|
||||
if getattr(module, "original_forward", None) is None:
|
||||
module.original_forward = module.forward
|
||||
if i == 0:
|
||||
# mid_block
|
||||
module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
|
||||
elif isinstance(module, CrossAttnDownBlock2D):
|
||||
module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
|
||||
elif isinstance(module, DownBlock2D):
|
||||
module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
|
||||
elif isinstance(module, CrossAttnUpBlock2D):
|
||||
module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
|
||||
elif isinstance(module, UpBlock2D):
|
||||
module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
|
||||
module.mean_bank = []
|
||||
module.var_bank = []
|
||||
module.gn_weight *= 2
|
||||
|
||||
# 10. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# 11. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# 10.1 Apply denoising_end
|
||||
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
- (denoising_end * self.scheduler.config.num_train_timesteps)
|
||||
)
|
||||
)
|
||||
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
||||
timesteps = timesteps[:num_inference_steps]
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
|
||||
# ref only part
|
||||
noise = randn_tensor(
|
||||
ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
|
||||
)
|
||||
ref_xt = self.scheduler.add_noise(
|
||||
ref_image_latents,
|
||||
noise,
|
||||
t.reshape(
|
||||
1,
|
||||
),
|
||||
)
|
||||
ref_xt = self.scheduler.scale_model_input(ref_xt, t)
|
||||
|
||||
MODE = "write"
|
||||
|
||||
self.unet(
|
||||
ref_xt,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
# predict the noise residual
|
||||
MODE = "read"
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
else:
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
# apply watermark if available
|
||||
if self.watermark is not None:
|
||||
image = self.watermark.apply_watermark(image)
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1209,6 +1209,8 @@ def main(args):
|
||||
break
|
||||
|
||||
if accelerator.is_main_process:
|
||||
images = []
|
||||
|
||||
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
|
||||
@@ -673,6 +673,8 @@ likely the learning rate can be increased with larger batch sizes.
|
||||
|
||||
Using 8bit adam and a batch size of 4, the model can be trained in ~48 GB VRAM.
|
||||
|
||||
`--validation_scheduler`: Set a particular scheduler via a string. We found that it is better to use the DDPMScheduler for validation when training DeepFloyd IF.
|
||||
|
||||
```sh
|
||||
export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
|
||||
|
||||
@@ -697,6 +699,7 @@ accelerate launch train_dreambooth.py \
|
||||
--use_8bit_adam \
|
||||
--set_grads_to_none \
|
||||
--skip_save_text_encoder \
|
||||
--validation_scheduler DDPMScheduler \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
@@ -735,6 +738,7 @@ accelerate launch train_dreambooth.py \
|
||||
--text_encoder_use_attention_mask \
|
||||
--validation_images $VALIDATION_IMAGES \
|
||||
--class_labels_conditioning timesteps \
|
||||
--validation_scheduler DDPMScheduler\
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import argparse
|
||||
import copy
|
||||
import gc
|
||||
import hashlib
|
||||
import importlib
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
@@ -47,7 +48,6 @@ from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
DiffusionPipeline,
|
||||
DPMSolverMultistepScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -153,7 +153,9 @@ def log_validation(
|
||||
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
module = importlib.import_module("diffusers")
|
||||
scheduler_class = getattr(module, args.validation_scheduler)
|
||||
pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
@@ -556,6 +558,13 @@ def parse_args(input_args=None):
|
||||
default=None,
|
||||
help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_scheduler",
|
||||
type=str,
|
||||
default="DPMSolverMultistepScheduler",
|
||||
choices=["DPMSolverMultistepScheduler", "DDPMScheduler"],
|
||||
help="Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF.",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
|
||||
@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -999,7 +999,7 @@ def main(args):
|
||||
validation_prompt_encoder_hidden_states = None
|
||||
|
||||
if args.class_prompt is not None:
|
||||
pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
|
||||
pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
|
||||
else:
|
||||
pre_computed_class_prompt_encoder_hidden_states = None
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -843,11 +843,15 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
|
||||
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
|
||||
)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
|
||||
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
|
||||
)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
|
||||
@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
+153
-1
@@ -122,7 +122,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.bin")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.safetensors")))
|
||||
|
||||
def test_dreambooth(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@@ -421,6 +421,77 @@ class ExamplesTestsAccelerate(unittest.TestCase):
|
||||
)
|
||||
self.assertTrue(starts_with_unet)
|
||||
|
||||
def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
|
||||
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora_sdxl.py
|
||||
--pretrained_model_name_or_path {pipeline_path}
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--checkpointing_steps=2
|
||||
--checkpoints_total_limit=2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe("a prompt", num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
|
||||
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora_sdxl.py
|
||||
--pretrained_model_name_or_path {pipeline_path}
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--checkpointing_steps=2
|
||||
--checkpoints_total_limit=2
|
||||
--train_text_encoder
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe("a prompt", num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_custom_diffusion(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
@@ -828,6 +899,87 @@ class ExamplesTestsAccelerate(unittest.TestCase):
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self):
|
||||
prompt = "a prompt"
|
||||
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# Should create checkpoints at steps 2, 4, 6
|
||||
# with checkpoint at step 2 deleted
|
||||
|
||||
initial_run_args = f"""
|
||||
examples/text_to_image/train_text_to_image_lora_sdxl.py
|
||||
--pretrained_model_name_or_path {pipeline_path}
|
||||
--dataset_name hf-internal-testing/dummy_image_text_data
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + initial_run_args)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
|
||||
prompt = "a prompt"
|
||||
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# Should create checkpoints at steps 2, 4, 6
|
||||
# with checkpoint at step 2 deleted
|
||||
|
||||
initial_run_args = f"""
|
||||
examples/text_to_image/train_text_to_image_lora_sdxl.py
|
||||
--pretrained_model_name_or_path {pipeline_path}
|
||||
--dataset_name hf-internal-testing/dummy_image_text_data
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--train_text_encoder
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + initial_run_args)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
|
||||
prompt = "a prompt"
|
||||
|
||||
@@ -50,12 +50,12 @@ When running `accelerate config`, if we specify torch compile mode to True there
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
export VAE="madebyollin/sdxl-vae-fp16-fix"
|
||||
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
|
||||
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
|
||||
|
||||
accelerate launch train_text_to_image_sdxl.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--pretrained_vae_model_name_or_path=$VAE \
|
||||
--pretrained_vae_model_name_or_path=$VAE_NAME \
|
||||
--dataset_name=$DATASET_NAME \
|
||||
--enable_xformers_memory_efficient_attention \
|
||||
--resolution=512 --center_crop --random_flip \
|
||||
@@ -78,6 +78,7 @@ accelerate launch train_text_to_image_sdxl.py \
|
||||
* The `train_text_to_image_sdxl.py` script pre-computes text embeddings and the VAE encodings and keeps them in memory. While for smaller datasets like [`lambdalabs/pokemon-blip-captions`](https://hf.co/datasets/lambdalabs/pokemon-blip-captions), it might not be a problem, it can definitely lead to memory problems when the script is used on a larger dataset. For those purposes, you would want to serialize these pre-computed representations to disk separately and load them during the fine-tuning process. Refer to [this PR](https://github.com/huggingface/diffusers/pull/4505) for a more in-depth discussion.
|
||||
* The training script is compute-intensive and may not run on a consumer GPU like Tesla T4.
|
||||
* The training command shown above performs intermediate quality validation in between the training epochs and logs the results to Weights and Biases. `--report_to`, `--validation_prompt`, and `--validation_epochs` are the relevant CLI arguments here.
|
||||
* SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
|
||||
|
||||
### Inference
|
||||
|
||||
@@ -111,12 +112,13 @@ on consumer GPUs like Tesla T4, Tesla V100.
|
||||
|
||||
### Training
|
||||
|
||||
First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
|
||||
First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables and, optionally, the `VAE_NAME` variable. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
|
||||
|
||||
**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see generating images during training. All you need to do is to run `pip install wandb` before training to automatically log images.___**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
|
||||
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
|
||||
```
|
||||
|
||||
@@ -132,11 +134,13 @@ Now we can start training!
|
||||
```bash
|
||||
accelerate launch train_text_to_image_lora_sdxl.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--pretrained_vae_model_name_or_path=$VAE_NAME \
|
||||
--dataset_name=$DATASET_NAME --caption_column="text" \
|
||||
--resolution=1024 --random_flip \
|
||||
--train_batch_size=1 \
|
||||
--num_train_epochs=2 --checkpointing_steps=500 \
|
||||
--learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
|
||||
--mixed_precision="fp16" \
|
||||
--seed=42 \
|
||||
--output_dir="sd-pokemon-model-lora-sdxl" \
|
||||
--validation_prompt="cute dragon creature" --report_to="wandb" \
|
||||
@@ -145,6 +149,10 @@ accelerate launch train_text_to_image_lora_sdxl.py \
|
||||
|
||||
The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases.
|
||||
|
||||
**Notes**:
|
||||
|
||||
* SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
|
||||
|
||||
### Finetuning the text encoder and UNet
|
||||
|
||||
The script also allows you to finetune the `text_encoder` along with the `unet`.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
accelerate>=0.16.0
|
||||
accelerate>=0.22.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
ftfy
|
||||
|
||||
@@ -53,7 +53,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.19.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -396,16 +396,6 @@ def parse_args(input_args=None):
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prior_generation_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp32", "fp16", "bf16"],
|
||||
help=(
|
||||
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
@@ -724,11 +714,15 @@ def main(args):
|
||||
|
||||
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
|
||||
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
|
||||
)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
|
||||
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
|
||||
)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
@@ -1002,9 +996,12 @@ def main(args):
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
|
||||
|
||||
# Convert images to latent space
|
||||
if args.pretrained_vae_model_name_or_path is not None:
|
||||
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
|
||||
else:
|
||||
pixel_values = batch["pixel_values"]
|
||||
|
||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||
model_input = model_input * vae.config.scaling_factor
|
||||
if args.pretrained_vae_model_name_or_path is None:
|
||||
@@ -1147,13 +1144,6 @@ def main(args):
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
# create pipeline
|
||||
if not args.train_text_encoder:
|
||||
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
)
|
||||
text_encoder_two = text_encoder_cls_two.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
|
||||
)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
@@ -1198,14 +1188,13 @@ def main(args):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unet.to(torch.float32)
|
||||
unet_lora_layers = unet_attn_processors_state_dict(unet)
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
||||
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
|
||||
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one)
|
||||
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
|
||||
text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
|
||||
text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two)
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
text_encoder_2_lora_layers = None
|
||||
@@ -1217,14 +1206,15 @@ def main(args):
|
||||
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
|
||||
)
|
||||
|
||||
del unet
|
||||
del text_encoder_one
|
||||
del text_encoder_two
|
||||
del text_encoder_lora_layers
|
||||
del text_encoder_2_lora_layers
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
vae_path,
|
||||
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
||||
revision=args.revision,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
|
||||
)
|
||||
|
||||
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -887,7 +887,12 @@ def main():
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
if global_step % args.save_steps == 0:
|
||||
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
|
||||
weight_name = (
|
||||
f"learned_embeds-steps-{global_step}.bin"
|
||||
if args.no_safe_serialization
|
||||
else f"learned_embeds-steps-{global_step}.safetensors"
|
||||
)
|
||||
save_path = os.path.join(args.output_dir, weight_name)
|
||||
save_progress(
|
||||
text_encoder,
|
||||
placeholder_token_ids,
|
||||
@@ -952,7 +957,8 @@ def main():
|
||||
)
|
||||
pipeline.save_pretrained(args.output_dir)
|
||||
# Save the newly trained embeddings
|
||||
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
|
||||
weight_name = "learned_embeds.bin" if args.no_safe_serialization else "learned_embeds.safetensors"
|
||||
save_path = os.path.join(args.output_dir, weight_name)
|
||||
save_progress(
|
||||
text_encoder,
|
||||
placeholder_token_ids,
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.0.dev0")
|
||||
check_min_version("0.21.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -0,0 +1,340 @@
|
||||
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
|
||||
# *Only* converts the UNet, VAE, and Text Encoder.
|
||||
# Does not convert optimizer state or any other thing.
|
||||
|
||||
import argparse
|
||||
import os.path as osp
|
||||
import re
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
|
||||
# =================#
|
||||
# UNet Conversion #
|
||||
# =================#
|
||||
|
||||
unet_conversion_map = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
||||
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
||||
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
||||
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
||||
("input_blocks.0.0.weight", "conv_in.weight"),
|
||||
("input_blocks.0.0.bias", "conv_in.bias"),
|
||||
("out.0.weight", "conv_norm_out.weight"),
|
||||
("out.0.bias", "conv_norm_out.bias"),
|
||||
("out.2.weight", "conv_out.weight"),
|
||||
("out.2.bias", "conv_out.bias"),
|
||||
# the following are for sdxl
|
||||
("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
|
||||
("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
|
||||
("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
|
||||
("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
|
||||
]
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0", "norm1"),
|
||||
("in_layers.2", "conv1"),
|
||||
("out_layers.0", "norm2"),
|
||||
("out_layers.3", "conv2"),
|
||||
("emb_layers.1", "time_emb_proj"),
|
||||
("skip_connection", "conv_shortcut"),
|
||||
]
|
||||
|
||||
unet_conversion_map_layer = []
|
||||
# hardcoded number of downblocks and resnets/attentions...
|
||||
# would need smarter logic for other networks.
|
||||
for i in range(3):
|
||||
# loop over downblocks/upblocks
|
||||
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i > 0:
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(4):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
if i < 2:
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
def convert_unet_state_dict(unet_state_dict):
|
||||
# buyer beware: this is a *brittle* function,
|
||||
# and correct output requires that all of these pieces interact in
|
||||
# the exact order in which I have arranged them.
|
||||
mapping = {k: k for k in unet_state_dict.keys()}
|
||||
for sd_name, hf_name in unet_conversion_map:
|
||||
mapping[hf_name] = sd_name
|
||||
for k, v in mapping.items():
|
||||
if "resnets" in k:
|
||||
for sd_part, hf_part in unet_conversion_map_resnet:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
for k, v in mapping.items():
|
||||
for sd_part, hf_part in unet_conversion_map_layer:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
new_state_dict = {sd_name: unet_state_dict[hf_name] for hf_name, sd_name in mapping.items()}
|
||||
return new_state_dict
|
||||
|
||||
|
||||
# ================#
|
||||
# VAE Conversion #
|
||||
# ================#
|
||||
|
||||
vae_conversion_map = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("nin_shortcut", "conv_shortcut"),
|
||||
("norm_out", "conv_norm_out"),
|
||||
("mid.attn_1.", "mid_block.attentions.0."),
|
||||
]
|
||||
|
||||
for i in range(4):
|
||||
# down_blocks have two resnets
|
||||
for j in range(2):
|
||||
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
||||
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
||||
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
||||
|
||||
if i < 3:
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
||||
sd_downsample_prefix = f"down.{i}.downsample."
|
||||
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"up.{3-i}.upsample."
|
||||
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
# up_blocks have three resnets
|
||||
# also, up blocks in hf are numbered in reverse from sd
|
||||
for j in range(3):
|
||||
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
||||
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
||||
|
||||
# this part accounts for mid blocks in both the encoder and the decoder
|
||||
for i in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
||||
sd_mid_res_prefix = f"mid.block_{i+1}."
|
||||
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
vae_conversion_map_attn = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("norm.", "group_norm."),
|
||||
# the following are for SDXL
|
||||
("q.", "to_q."),
|
||||
("k.", "to_k."),
|
||||
("v.", "to_v."),
|
||||
("proj_out.", "to_out.0."),
|
||||
]
|
||||
|
||||
|
||||
def reshape_weight_for_sd(w):
|
||||
# convert HF linear weights to SD conv2d weights
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
|
||||
|
||||
def convert_vae_state_dict(vae_state_dict):
|
||||
mapping = {k: k for k in vae_state_dict.keys()}
|
||||
for k, v in mapping.items():
|
||||
for sd_part, hf_part in vae_conversion_map:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
for k, v in mapping.items():
|
||||
if "attentions" in k:
|
||||
for sd_part, hf_part in vae_conversion_map_attn:
|
||||
v = v.replace(hf_part, sd_part)
|
||||
mapping[k] = v
|
||||
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
||||
weights_to_convert = ["q", "k", "v", "proj_out"]
|
||||
for k, v in new_state_dict.items():
|
||||
for weight_name in weights_to_convert:
|
||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||
print(f"Reshaping {k} for SD format")
|
||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||
return new_state_dict
|
||||
|
||||
|
||||
# =========================#
|
||||
# Text Encoder Conversion #
|
||||
# =========================#
|
||||
|
||||
|
||||
textenc_conversion_lst = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("transformer.resblocks.", "text_model.encoder.layers."),
|
||||
("ln_1", "layer_norm1"),
|
||||
("ln_2", "layer_norm2"),
|
||||
(".c_fc.", ".fc1."),
|
||||
(".c_proj.", ".fc2."),
|
||||
(".attn", ".self_attn"),
|
||||
("ln_final.", "text_model.final_layer_norm."),
|
||||
("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
|
||||
("positional_embedding", "text_model.embeddings.position_embedding.weight"),
|
||||
]
|
||||
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
|
||||
textenc_pattern = re.compile("|".join(protected.keys()))
|
||||
|
||||
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
||||
code2idx = {"q": 0, "k": 1, "v": 2}
|
||||
|
||||
|
||||
def convert_openclip_text_enc_state_dict(text_enc_dict):
|
||||
new_state_dict = {}
|
||||
capture_qkv_weight = {}
|
||||
capture_qkv_bias = {}
|
||||
for k, v in text_enc_dict.items():
|
||||
if (
|
||||
k.endswith(".self_attn.q_proj.weight")
|
||||
or k.endswith(".self_attn.k_proj.weight")
|
||||
or k.endswith(".self_attn.v_proj.weight")
|
||||
):
|
||||
k_pre = k[: -len(".q_proj.weight")]
|
||||
k_code = k[-len("q_proj.weight")]
|
||||
if k_pre not in capture_qkv_weight:
|
||||
capture_qkv_weight[k_pre] = [None, None, None]
|
||||
capture_qkv_weight[k_pre][code2idx[k_code]] = v
|
||||
continue
|
||||
|
||||
if (
|
||||
k.endswith(".self_attn.q_proj.bias")
|
||||
or k.endswith(".self_attn.k_proj.bias")
|
||||
or k.endswith(".self_attn.v_proj.bias")
|
||||
):
|
||||
k_pre = k[: -len(".q_proj.bias")]
|
||||
k_code = k[-len("q_proj.bias")]
|
||||
if k_pre not in capture_qkv_bias:
|
||||
capture_qkv_bias[k_pre] = [None, None, None]
|
||||
capture_qkv_bias[k_pre][code2idx[k_code]] = v
|
||||
continue
|
||||
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
|
||||
new_state_dict[relabelled_key] = v
|
||||
|
||||
for k_pre, tensors in capture_qkv_weight.items():
|
||||
if None in tensors:
|
||||
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
||||
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
|
||||
|
||||
for k_pre, tensors in capture_qkv_bias.items():
|
||||
if None in tensors:
|
||||
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
||||
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
||||
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_openai_text_enc_state_dict(text_enc_dict):
|
||||
return text_enc_dict
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
|
||||
parser.add_argument(
|
||||
"--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.model_path is not None, "Must provide a model path!"
|
||||
|
||||
assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
|
||||
|
||||
# Path for safetensors
|
||||
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
|
||||
vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
|
||||
text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
|
||||
text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "model.safetensors")
|
||||
|
||||
# Load models from safetensors if it exists, if it doesn't pytorch
|
||||
if osp.exists(unet_path):
|
||||
unet_state_dict = load_file(unet_path, device="cpu")
|
||||
else:
|
||||
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
|
||||
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
||||
|
||||
if osp.exists(vae_path):
|
||||
vae_state_dict = load_file(vae_path, device="cpu")
|
||||
else:
|
||||
vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
|
||||
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
||||
|
||||
if osp.exists(text_enc_path):
|
||||
text_enc_dict = load_file(text_enc_path, device="cpu")
|
||||
else:
|
||||
text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
|
||||
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
||||
|
||||
if osp.exists(text_enc_2_path):
|
||||
text_enc_2_dict = load_file(text_enc_2_path, device="cpu")
|
||||
else:
|
||||
text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "pytorch_model.bin")
|
||||
text_enc_2_dict = torch.load(text_enc_2_path, map_location="cpu")
|
||||
|
||||
# Convert the UNet model
|
||||
unet_state_dict = convert_unet_state_dict(unet_state_dict)
|
||||
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
|
||||
|
||||
# Convert the VAE model
|
||||
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
||||
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
||||
|
||||
text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict)
|
||||
text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()}
|
||||
|
||||
text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict)
|
||||
text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()}
|
||||
|
||||
# Put together new checkpoint
|
||||
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}
|
||||
|
||||
if args.half:
|
||||
state_dict = {k: v.half() for k, v in state_dict.items()}
|
||||
|
||||
if args.use_safetensors:
|
||||
save_file(state_dict, args.checkpoint_path)
|
||||
else:
|
||||
state_dict = {"state_dict": state_dict}
|
||||
torch.save(state_dict, args.checkpoint_path)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -152,7 +152,7 @@ if __name__ == "__main__":
|
||||
pipeline_class = None
|
||||
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path=args.checkpoint_path,
|
||||
checkpoint_path_or_dict=args.checkpoint_path,
|
||||
original_config_file=args.original_config_file,
|
||||
image_size=args.image_size,
|
||||
prediction_type=args.prediction_type,
|
||||
|
||||
@@ -44,6 +44,11 @@ To create the package for pypi.
|
||||
For the sources, run: "python setup.py sdist"
|
||||
You should now have a /dist directory with both .whl and .tar.gz source versions.
|
||||
|
||||
Long story cut short, you need to run both before you can upload the distribution to the
|
||||
test pypi and the actual pypi servers:
|
||||
|
||||
python setup.py bdist_wheel && python setup.py sdist
|
||||
|
||||
8. Check that everything looks correct by uploading the package to the pypi test server:
|
||||
|
||||
twine upload dist/* -r pypitest
|
||||
@@ -54,14 +59,20 @@ To create the package for pypi.
|
||||
Check that you can install it in a virtualenv by running:
|
||||
pip install -i https://testpypi.python.org/pypi diffusers
|
||||
|
||||
If you are testing from a Colab Notebook, for instance, then do:
|
||||
pip install diffusers && pip uninstall diffusers
|
||||
pip install -i https://testpypi.python.org/pypi diffusers
|
||||
|
||||
Check you can run the following commands:
|
||||
python -c "from diffusers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
|
||||
python -c "python -c "from diffusers import __version__; print(__version__)"
|
||||
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
|
||||
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"
|
||||
python -c "from diffusers import *"
|
||||
|
||||
9. Upload the final version to actual pypi:
|
||||
twine upload dist/* -r pypi
|
||||
|
||||
10. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
|
||||
10. Prepare the release notes and publish them on github once everything is looking hunky-dory.
|
||||
|
||||
11. Run `make post-release` (or, for a patch release, `make post-patch`). If you were on a branch for the release,
|
||||
you need to go back to main before executing this.
|
||||
@@ -233,11 +244,11 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.20.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="Diffusers",
|
||||
version="0.21.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
keywords="deep learning",
|
||||
keywords="deep learning diffusion jax pytorch stable diffusion audioldm",
|
||||
license="Apache",
|
||||
author="The HuggingFace team",
|
||||
author_email="patrick@huggingface.co",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.20.0.dev0"
|
||||
__version__ = "0.21.0.dev0"
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .utils import (
|
||||
@@ -133,6 +133,9 @@ else:
|
||||
from .pipelines import (
|
||||
AltDiffusionImg2ImgPipeline,
|
||||
AltDiffusionPipeline,
|
||||
AudioLDM2Pipeline,
|
||||
AudioLDM2ProjectionModel,
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
AudioLDMPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
IFImg2ImgPipeline,
|
||||
@@ -160,6 +163,7 @@ else:
|
||||
KandinskyV22PriorEmb2EmbPipeline,
|
||||
KandinskyV22PriorPipeline,
|
||||
LDMTextToImagePipeline,
|
||||
MusicLDMPipeline,
|
||||
PaintByExamplePipeline,
|
||||
SemanticStableDiffusionPipeline,
|
||||
ShapEImg2ImgPipeline,
|
||||
@@ -187,6 +191,8 @@ else:
|
||||
StableDiffusionPix2PixZeroPipeline,
|
||||
StableDiffusionSAGPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
StableDiffusionXLAdapterPipeline,
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
|
||||
@@ -24,6 +24,16 @@ from .configuration_utils import ConfigMixin, register_to_config
|
||||
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
|
||||
|
||||
|
||||
PipelineImageInput = Union[
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
torch.FloatTensor,
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
List[torch.FloatTensor],
|
||||
]
|
||||
|
||||
|
||||
class VaeImageProcessor(ConfigMixin):
|
||||
"""
|
||||
Image processor for VAE.
|
||||
@@ -38,8 +48,12 @@ class VaeImageProcessor(ConfigMixin):
|
||||
Resampling filter to use when resizing the image.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image to [-1,1].
|
||||
do_binarize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to binarize the image to 0/1.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to be `False`):
|
||||
Whether to convert the images to RGB format.
|
||||
do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
|
||||
Whether to convert the images to grayscale format.
|
||||
"""
|
||||
|
||||
config_name = CONFIG_NAME
|
||||
@@ -51,9 +65,18 @@ class VaeImageProcessor(ConfigMixin):
|
||||
vae_scale_factor: int = 8,
|
||||
resample: str = "lanczos",
|
||||
do_normalize: bool = True,
|
||||
do_binarize: bool = False,
|
||||
do_convert_rgb: bool = False,
|
||||
do_convert_grayscale: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if do_convert_rgb and do_convert_grayscale:
|
||||
raise ValueError(
|
||||
"`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
|
||||
" if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
|
||||
" if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
|
||||
)
|
||||
self.config.do_convert_rgb = False
|
||||
|
||||
@staticmethod
|
||||
def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
|
||||
@@ -119,29 +142,95 @@ class VaeImageProcessor(ConfigMixin):
|
||||
@staticmethod
|
||||
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
|
||||
"""
|
||||
Converts an image to RGB format.
|
||||
Converts a PIL image to RGB format.
|
||||
"""
|
||||
image = image.convert("RGB")
|
||||
|
||||
return image
|
||||
|
||||
def resize(
|
||||
@staticmethod
|
||||
def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
|
||||
"""
|
||||
Converts a PIL image to grayscale format.
|
||||
"""
|
||||
image = image.convert("L")
|
||||
|
||||
return image
|
||||
|
||||
def get_default_height_width(
|
||||
self,
|
||||
image: PIL.Image.Image,
|
||||
image: [PIL.Image.Image, np.ndarray, torch.Tensor],
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
) -> PIL.Image.Image:
|
||||
):
|
||||
"""
|
||||
Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
|
||||
This function return the height and width that are downscaled to the next integer multiple of
|
||||
`vae_scale_factor`.
|
||||
|
||||
Args:
|
||||
image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
|
||||
The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
|
||||
shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
|
||||
have shape `[batch, channel, height, width]`.
|
||||
height (`int`, *optional*, defaults to `None`):
|
||||
The height in preprocessed image. If `None`, will use the height of `image` input.
|
||||
width (`int`, *optional*`, defaults to `None`):
|
||||
The width in preprocessed. If `None`, will use the width of the `image` input.
|
||||
"""
|
||||
|
||||
if height is None:
|
||||
height = image.height
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
height = image.height
|
||||
elif isinstance(image, torch.Tensor):
|
||||
height = image.shape[2]
|
||||
else:
|
||||
height = image.shape[1]
|
||||
|
||||
if width is None:
|
||||
width = image.width
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
width = image.width
|
||||
elif isinstance(image, torch.Tensor):
|
||||
width = image.shape[3]
|
||||
else:
|
||||
height = image.shape[2]
|
||||
|
||||
width, height = (
|
||||
x - x % self.config.vae_scale_factor for x in (width, height)
|
||||
) # resize to integer multiple of vae_scale_factor
|
||||
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
|
||||
|
||||
return height, width
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: [PIL.Image.Image, np.ndarray, torch.Tensor],
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
) -> [PIL.Image.Image, np.ndarray, torch.Tensor]:
|
||||
"""
|
||||
Resize image.
|
||||
"""
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
|
||||
elif isinstance(image, torch.Tensor):
|
||||
image = torch.nn.functional.interpolate(
|
||||
image,
|
||||
size=(height, width),
|
||||
)
|
||||
elif isinstance(image, np.ndarray):
|
||||
image = self.numpy_to_pt(image)
|
||||
image = torch.nn.functional.interpolate(
|
||||
image,
|
||||
size=(height, width),
|
||||
)
|
||||
image = self.pt_to_numpy(image)
|
||||
return image
|
||||
|
||||
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
|
||||
"""
|
||||
create a mask
|
||||
"""
|
||||
image[image < 0.5] = 0
|
||||
image[image >= 0.5] = 1
|
||||
return image
|
||||
|
||||
def preprocess(
|
||||
@@ -154,6 +243,25 @@ class VaeImageProcessor(ConfigMixin):
|
||||
Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
|
||||
"""
|
||||
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
||||
|
||||
# Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
|
||||
if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
|
||||
if isinstance(image, torch.Tensor):
|
||||
# if image is a pytorch tensor could have 2 possible shapes:
|
||||
# 1. batch x height x width: we should insert the channel dimension at position 1
|
||||
# 2. channnel x height x width: we should insert batch dimension at position 0,
|
||||
# however, since both channel and batch dimension has same size 1, it is same to insert at position 1
|
||||
# for simplicity, we insert a dimension of size 1 at position 1 for both cases
|
||||
image = image.unsqueeze(1)
|
||||
else:
|
||||
# if it is a numpy array, it could have 2 possible shapes:
|
||||
# 1. batch x height x width: insert channel dimension on last position
|
||||
# 2. height x width x channel: insert batch dimension on first position
|
||||
if image.shape[-1] == 1:
|
||||
image = np.expand_dims(image, axis=0)
|
||||
else:
|
||||
image = np.expand_dims(image, axis=-1)
|
||||
|
||||
if isinstance(image, supported_formats):
|
||||
image = [image]
|
||||
elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
|
||||
@@ -164,42 +272,41 @@ class VaeImageProcessor(ConfigMixin):
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
if self.config.do_convert_rgb:
|
||||
image = [self.convert_to_rgb(i) for i in image]
|
||||
elif self.config.do_convert_grayscale:
|
||||
image = [self.convert_to_grayscale(i) for i in image]
|
||||
if self.config.do_resize:
|
||||
height, width = self.get_default_height_width(image[0], height, width)
|
||||
image = [self.resize(i, height, width) for i in image]
|
||||
image = self.pil_to_numpy(image) # to np
|
||||
image = self.numpy_to_pt(image) # to pt
|
||||
|
||||
elif isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
|
||||
|
||||
image = self.numpy_to_pt(image)
|
||||
_, _, height, width = image.shape
|
||||
if self.config.do_resize and (
|
||||
height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
|
||||
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
|
||||
)
|
||||
|
||||
height, width = self.get_default_height_width(image, height, width)
|
||||
if self.config.do_resize:
|
||||
image = self.resize(image, height, width)
|
||||
|
||||
elif isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
|
||||
_, channel, height, width = image.shape
|
||||
|
||||
if self.config.do_convert_grayscale and image.ndim == 3:
|
||||
image = image.unsqueeze(1)
|
||||
|
||||
channel = image.shape[1]
|
||||
# don't need any preprocess if the image is latents
|
||||
if channel == 4:
|
||||
return image
|
||||
|
||||
if self.config.do_resize and (
|
||||
height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
|
||||
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
|
||||
)
|
||||
height, width = self.get_default_height_width(image, height, width)
|
||||
if self.config.do_resize:
|
||||
image = self.resize(image, height, width)
|
||||
|
||||
# expected range [0,1], normalize to [-1,1]
|
||||
do_normalize = self.config.do_normalize
|
||||
if image.min() < 0:
|
||||
if image.min() < 0 and do_normalize:
|
||||
warnings.warn(
|
||||
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
|
||||
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
|
||||
@@ -210,6 +317,9 @@ class VaeImageProcessor(ConfigMixin):
|
||||
if do_normalize:
|
||||
image = self.normalize(image)
|
||||
|
||||
if self.config.do_binarize:
|
||||
image = self.binarize(image)
|
||||
|
||||
return image
|
||||
|
||||
def postprocess(
|
||||
|
||||
+347
-168
@@ -24,8 +24,7 @@ from typing import Callable, Dict, List, Optional, Union
|
||||
import requests
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub import hf_hub_download, model_info
|
||||
from torch import nn
|
||||
|
||||
from .utils import (
|
||||
@@ -86,7 +85,58 @@ class PatchedLoraProjection(nn.Module):
|
||||
|
||||
self.lora_scale = lora_scale
|
||||
|
||||
# overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved
|
||||
# when saving the whole text encoder model and when LoRA is unloaded or fused
|
||||
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
|
||||
if self.lora_linear_layer is None:
|
||||
return self.regular_linear_layer.state_dict(
|
||||
*args, destination=destination, prefix=prefix, keep_vars=keep_vars
|
||||
)
|
||||
|
||||
return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
|
||||
def _fuse_lora(self):
|
||||
if self.lora_linear_layer is None:
|
||||
return
|
||||
|
||||
dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device
|
||||
|
||||
w_orig = self.regular_linear_layer.weight.data.float()
|
||||
w_up = self.lora_linear_layer.up.weight.data.float()
|
||||
w_down = self.lora_linear_layer.down.weight.data.float()
|
||||
|
||||
if self.lora_linear_layer.network_alpha is not None:
|
||||
w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
|
||||
|
||||
fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]
|
||||
self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
|
||||
|
||||
# we can drop the lora layer now
|
||||
self.lora_linear_layer = None
|
||||
|
||||
# offload the up and down matrices to CPU to not blow the memory
|
||||
self.w_up = w_up.cpu()
|
||||
self.w_down = w_down.cpu()
|
||||
|
||||
def _unfuse_lora(self):
|
||||
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
|
||||
return
|
||||
|
||||
fused_weight = self.regular_linear_layer.weight.data
|
||||
dtype, device = fused_weight.dtype, fused_weight.device
|
||||
|
||||
w_up = self.w_up.to(device=device).float()
|
||||
w_down = self.w_down.to(device).float()
|
||||
|
||||
unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
|
||||
self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
|
||||
|
||||
self.w_up = None
|
||||
self.w_down = None
|
||||
|
||||
def forward(self, input):
|
||||
if self.lora_linear_layer is None:
|
||||
return self.regular_linear_layer(input)
|
||||
return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input)
|
||||
|
||||
|
||||
@@ -231,15 +281,7 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
"""
|
||||
from .models.attention_processor import (
|
||||
AttnAddedKVProcessor,
|
||||
AttnAddedKVProcessor2_0,
|
||||
CustomDiffusionAttnProcessor,
|
||||
LoRAAttnAddedKVProcessor,
|
||||
LoRAAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
SlicedAttnAddedKVProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from .models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
|
||||
|
||||
@@ -314,24 +356,14 @@ class UNet2DConditionLoadersMixin:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
# fill attn processors
|
||||
attn_processors = {}
|
||||
non_attn_lora_layers = []
|
||||
lora_layers_list = []
|
||||
|
||||
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
|
||||
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
|
||||
|
||||
if is_lora:
|
||||
is_new_lora_format = all(
|
||||
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
|
||||
)
|
||||
if is_new_lora_format:
|
||||
# Strip the `"unet"` prefix.
|
||||
is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
|
||||
if is_text_encoder_present:
|
||||
warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
|
||||
warnings.warn(warn_message)
|
||||
unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
|
||||
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
|
||||
# correct keys
|
||||
state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)
|
||||
|
||||
lora_grouped_dict = defaultdict(dict)
|
||||
mapped_network_alphas = {}
|
||||
@@ -367,87 +399,38 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
# Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
|
||||
# or add_{k,v,q,out_proj}_proj_lora layers.
|
||||
if "lora.down.weight" in value_dict:
|
||||
rank = value_dict["lora.down.weight"].shape[0]
|
||||
rank = value_dict["lora.down.weight"].shape[0]
|
||||
|
||||
if isinstance(attn_processor, LoRACompatibleConv):
|
||||
in_features = attn_processor.in_channels
|
||||
out_features = attn_processor.out_channels
|
||||
kernel_size = attn_processor.kernel_size
|
||||
if isinstance(attn_processor, LoRACompatibleConv):
|
||||
in_features = attn_processor.in_channels
|
||||
out_features = attn_processor.out_channels
|
||||
kernel_size = attn_processor.kernel_size
|
||||
|
||||
lora = LoRAConv2dLayer(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
rank=rank,
|
||||
kernel_size=kernel_size,
|
||||
stride=attn_processor.stride,
|
||||
padding=attn_processor.padding,
|
||||
network_alpha=mapped_network_alphas.get(key),
|
||||
)
|
||||
elif isinstance(attn_processor, LoRACompatibleLinear):
|
||||
lora = LoRALinearLayer(
|
||||
attn_processor.in_features,
|
||||
attn_processor.out_features,
|
||||
rank,
|
||||
mapped_network_alphas.get(key),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
|
||||
|
||||
value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
|
||||
lora.load_state_dict(value_dict)
|
||||
non_attn_lora_layers.append((attn_processor, lora))
|
||||
lora = LoRAConv2dLayer(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
rank=rank,
|
||||
kernel_size=kernel_size,
|
||||
stride=attn_processor.stride,
|
||||
padding=attn_processor.padding,
|
||||
network_alpha=mapped_network_alphas.get(key),
|
||||
)
|
||||
elif isinstance(attn_processor, LoRACompatibleLinear):
|
||||
lora = LoRALinearLayer(
|
||||
attn_processor.in_features,
|
||||
attn_processor.out_features,
|
||||
rank,
|
||||
mapped_network_alphas.get(key),
|
||||
)
|
||||
else:
|
||||
# To handle SDXL.
|
||||
rank_mapping = {}
|
||||
hidden_size_mapping = {}
|
||||
for projection_id in ["to_k", "to_q", "to_v", "to_out"]:
|
||||
rank = value_dict[f"{projection_id}_lora.down.weight"].shape[0]
|
||||
hidden_size = value_dict[f"{projection_id}_lora.up.weight"].shape[0]
|
||||
raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
|
||||
|
||||
rank_mapping.update({f"{projection_id}_lora.down.weight": rank})
|
||||
hidden_size_mapping.update({f"{projection_id}_lora.up.weight": hidden_size})
|
||||
|
||||
if isinstance(
|
||||
attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
|
||||
):
|
||||
cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
|
||||
attn_processor_class = LoRAAttnAddedKVProcessor
|
||||
else:
|
||||
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
|
||||
if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
|
||||
attn_processor_class = LoRAXFormersAttnProcessor
|
||||
else:
|
||||
attn_processor_class = (
|
||||
LoRAAttnProcessor2_0
|
||||
if hasattr(F, "scaled_dot_product_attention")
|
||||
else LoRAAttnProcessor
|
||||
)
|
||||
|
||||
if attn_processor_class is not LoRAAttnAddedKVProcessor:
|
||||
attn_processors[key] = attn_processor_class(
|
||||
rank=rank_mapping.get("to_k_lora.down.weight"),
|
||||
hidden_size=hidden_size_mapping.get("to_k_lora.up.weight"),
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
network_alpha=mapped_network_alphas.get(key),
|
||||
q_rank=rank_mapping.get("to_q_lora.down.weight"),
|
||||
q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight"),
|
||||
v_rank=rank_mapping.get("to_v_lora.down.weight"),
|
||||
v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight"),
|
||||
out_rank=rank_mapping.get("to_out_lora.down.weight"),
|
||||
out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight"),
|
||||
)
|
||||
else:
|
||||
attn_processors[key] = attn_processor_class(
|
||||
rank=rank_mapping.get("to_k_lora.down.weight", None),
|
||||
hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
network_alpha=mapped_network_alphas.get(key),
|
||||
)
|
||||
|
||||
attn_processors[key].load_state_dict(value_dict)
|
||||
value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
|
||||
lora.load_state_dict(value_dict)
|
||||
lora_layers_list.append((attn_processor, lora))
|
||||
|
||||
elif is_custom_diffusion:
|
||||
attn_processors = {}
|
||||
custom_diffusion_grouped_dict = defaultdict(dict)
|
||||
for key, value in state_dict.items():
|
||||
if len(value) == 0:
|
||||
@@ -475,22 +458,47 @@ class UNet2DConditionLoadersMixin:
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
attn_processors[key].load_state_dict(value_dict)
|
||||
|
||||
self.set_attn_processor(attn_processors)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
|
||||
)
|
||||
|
||||
# set correct dtype & device
|
||||
attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
|
||||
non_attn_lora_layers = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in non_attn_lora_layers]
|
||||
lora_layers_list = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in lora_layers_list]
|
||||
|
||||
# set layers
|
||||
self.set_attn_processor(attn_processors)
|
||||
|
||||
# set ff layers
|
||||
for target_module, lora_layer in non_attn_lora_layers:
|
||||
# set lora layers
|
||||
for target_module, lora_layer in lora_layers_list:
|
||||
target_module.set_lora_layer(lora_layer)
|
||||
|
||||
def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
|
||||
is_new_lora_format = all(
|
||||
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
|
||||
)
|
||||
if is_new_lora_format:
|
||||
# Strip the `"unet"` prefix.
|
||||
is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
|
||||
if is_text_encoder_present:
|
||||
warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
|
||||
logger.warn(warn_message)
|
||||
unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
|
||||
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
|
||||
|
||||
# change processor format to 'pure' LoRACompatibleLinear format
|
||||
if any("processor" in k.split(".") for k in state_dict.keys()):
|
||||
|
||||
def format_to_lora_compatible(key):
|
||||
if "processor" not in key.split("."):
|
||||
return key
|
||||
return key.replace(".processor", "").replace("to_out_lora", "to_out.0.lora").replace("_lora", ".lora")
|
||||
|
||||
state_dict = {format_to_lora_compatible(k): v for k, v in state_dict.items()}
|
||||
|
||||
if network_alphas is not None:
|
||||
network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()}
|
||||
return state_dict, network_alphas
|
||||
|
||||
def save_attn_procs(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
@@ -568,6 +576,20 @@ class UNet2DConditionLoadersMixin:
|
||||
save_function(state_dict, os.path.join(save_directory, weight_name))
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
|
||||
|
||||
def fuse_lora(self):
|
||||
self.apply(self._fuse_lora_apply)
|
||||
|
||||
def _fuse_lora_apply(self, module):
|
||||
if hasattr(module, "_fuse_lora"):
|
||||
module._fuse_lora()
|
||||
|
||||
def unfuse_lora(self):
|
||||
self.apply(self._unfuse_lora_apply)
|
||||
|
||||
def _unfuse_lora_apply(self, module):
|
||||
if hasattr(module, "_unfuse_lora"):
|
||||
module._unfuse_lora()
|
||||
|
||||
|
||||
class TextualInversionLoaderMixin:
|
||||
r"""
|
||||
@@ -1021,6 +1043,13 @@ class LoraLoaderMixin:
|
||||
weight_name is not None and weight_name.endswith(".safetensors")
|
||||
):
|
||||
try:
|
||||
# Here we're relaxing the loading check to enable more Inference API
|
||||
# friendliness where sometimes, it's not at all possible to automatically
|
||||
# determine `weight_name`.
|
||||
if weight_name is None:
|
||||
weight_name = cls._best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
||||
@@ -1041,7 +1070,12 @@ class LoraLoaderMixin:
|
||||
# try loading non-safetensors weights
|
||||
model_file = None
|
||||
pass
|
||||
|
||||
if model_file is None:
|
||||
if weight_name is None:
|
||||
weight_name = cls._best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".bin"
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME,
|
||||
@@ -1072,30 +1106,78 @@ class LoraLoaderMixin:
|
||||
# Map SDXL blocks correctly.
|
||||
if unet_config is not None:
|
||||
# use unet config to remap block numbers
|
||||
state_dict = cls._map_sgm_blocks_to_diffusers(state_dict, unet_config)
|
||||
state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
|
||||
state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)
|
||||
|
||||
return state_dict, network_alphas
|
||||
|
||||
@classmethod
|
||||
def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
|
||||
is_all_unet = all(k.startswith("lora_unet") for k in state_dict)
|
||||
def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
|
||||
targeted_files = []
|
||||
|
||||
if os.path.isfile(pretrained_model_name_or_path_or_dict):
|
||||
return
|
||||
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
|
||||
targeted_files = [
|
||||
f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
|
||||
]
|
||||
else:
|
||||
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
|
||||
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
|
||||
if len(targeted_files) == 0:
|
||||
return
|
||||
|
||||
# "scheduler" does not correspond to a LoRA checkpoint.
|
||||
# "optimizer" does not correspond to a LoRA checkpoint
|
||||
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
|
||||
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
|
||||
targeted_files = list(
|
||||
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
|
||||
)
|
||||
|
||||
if len(targeted_files) > 1:
|
||||
raise ValueError(
|
||||
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
|
||||
)
|
||||
weight_name = targeted_files[0]
|
||||
return weight_name
|
||||
|
||||
@classmethod
|
||||
def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
|
||||
# 1. get all state_dict_keys
|
||||
all_keys = list(state_dict.keys())
|
||||
sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
|
||||
|
||||
# 2. check if needs remapping, if not return original dict
|
||||
is_in_sgm_format = False
|
||||
for key in all_keys:
|
||||
if any(p in key for p in sgm_patterns):
|
||||
is_in_sgm_format = True
|
||||
break
|
||||
|
||||
if not is_in_sgm_format:
|
||||
return state_dict
|
||||
|
||||
# 3. Else remap from SGM patterns
|
||||
new_state_dict = {}
|
||||
inner_block_map = ["resnets", "attentions", "upsamplers"]
|
||||
|
||||
# Retrieves # of down, mid and up blocks
|
||||
input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
|
||||
for layer in state_dict:
|
||||
if "text" not in layer:
|
||||
|
||||
for layer in all_keys:
|
||||
if "text" in layer:
|
||||
new_state_dict[layer] = state_dict.pop(layer)
|
||||
else:
|
||||
layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
|
||||
if "input_blocks" in layer:
|
||||
if sgm_patterns[0] in layer:
|
||||
input_block_ids.add(layer_id)
|
||||
elif "middle_block" in layer:
|
||||
elif sgm_patterns[1] in layer:
|
||||
middle_block_ids.add(layer_id)
|
||||
elif "output_blocks" in layer:
|
||||
elif sgm_patterns[2] in layer:
|
||||
output_block_ids.add(layer_id)
|
||||
else:
|
||||
raise ValueError("Checkpoint not supported")
|
||||
raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")
|
||||
|
||||
input_blocks = {
|
||||
layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
|
||||
@@ -1158,12 +1240,8 @@ class LoraLoaderMixin:
|
||||
)
|
||||
new_state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
if is_all_unet and len(state_dict) > 0:
|
||||
if len(state_dict) > 0:
|
||||
raise ValueError("At this point all state dict entries have to be converted.")
|
||||
else:
|
||||
# Remaining is the text encoder state dict.
|
||||
for k, v in state_dict.items():
|
||||
new_state_dict.update({k: v})
|
||||
|
||||
return new_state_dict
|
||||
|
||||
@@ -1203,7 +1281,7 @@ class LoraLoaderMixin:
|
||||
else:
|
||||
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
|
||||
# contain the module names of the `unet` as its keys WITHOUT any prefix.
|
||||
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
|
||||
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
|
||||
warnings.warn(warn_message)
|
||||
|
||||
# load loras into unet
|
||||
@@ -1245,6 +1323,7 @@ class LoraLoaderMixin:
|
||||
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
|
||||
if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
|
||||
# Convert from the old naming convention to the new naming convention.
|
||||
@@ -1283,10 +1362,17 @@ class LoraLoaderMixin:
|
||||
f"{name}.out_proj.lora_linear_layer.down.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
|
||||
|
||||
rank = text_encoder_lora_state_dict[
|
||||
"text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
|
||||
].shape[1]
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
|
||||
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
|
||||
rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [
|
||||
@@ -1328,15 +1414,15 @@ class LoraLoaderMixin:
|
||||
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj = attn_module.q_proj.regular_linear_layer
|
||||
attn_module.k_proj = attn_module.k_proj.regular_linear_layer
|
||||
attn_module.v_proj = attn_module.v_proj.regular_linear_layer
|
||||
attn_module.out_proj = attn_module.out_proj.regular_linear_layer
|
||||
attn_module.q_proj.lora_linear_layer = None
|
||||
attn_module.k_proj.lora_linear_layer = None
|
||||
attn_module.v_proj.lora_linear_layer = None
|
||||
attn_module.out_proj.lora_linear_layer = None
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1 = mlp_module.fc1.regular_linear_layer
|
||||
mlp_module.fc2 = mlp_module.fc2.regular_linear_layer
|
||||
mlp_module.fc1.lora_linear_layer = None
|
||||
mlp_module.fc2.lora_linear_layer = None
|
||||
|
||||
@classmethod
|
||||
def _modify_text_encoder(
|
||||
@@ -1344,7 +1430,7 @@ class LoraLoaderMixin:
|
||||
text_encoder,
|
||||
lora_scale=1,
|
||||
network_alphas=None,
|
||||
rank=4,
|
||||
rank: Union[Dict[str, int], int] = 4,
|
||||
dtype=None,
|
||||
patch_mlp=False,
|
||||
):
|
||||
@@ -1365,38 +1451,76 @@ class LoraLoaderMixin:
|
||||
value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
|
||||
out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)
|
||||
|
||||
if isinstance(rank, dict):
|
||||
current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight")
|
||||
else:
|
||||
current_rank = rank
|
||||
|
||||
q_linear_layer = (
|
||||
attn_module.q_proj.regular_linear_layer
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection)
|
||||
else attn_module.q_proj
|
||||
)
|
||||
attn_module.q_proj = PatchedLoraProjection(
|
||||
attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype
|
||||
q_linear_layer, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
|
||||
|
||||
k_linear_layer = (
|
||||
attn_module.k_proj.regular_linear_layer
|
||||
if isinstance(attn_module.k_proj, PatchedLoraProjection)
|
||||
else attn_module.k_proj
|
||||
)
|
||||
attn_module.k_proj = PatchedLoraProjection(
|
||||
attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=rank, dtype=dtype
|
||||
k_linear_layer, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
|
||||
|
||||
v_linear_layer = (
|
||||
attn_module.v_proj.regular_linear_layer
|
||||
if isinstance(attn_module.v_proj, PatchedLoraProjection)
|
||||
else attn_module.v_proj
|
||||
)
|
||||
attn_module.v_proj = PatchedLoraProjection(
|
||||
attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=rank, dtype=dtype
|
||||
v_linear_layer, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
|
||||
|
||||
out_linear_layer = (
|
||||
attn_module.out_proj.regular_linear_layer
|
||||
if isinstance(attn_module.out_proj, PatchedLoraProjection)
|
||||
else attn_module.out_proj
|
||||
)
|
||||
attn_module.out_proj = PatchedLoraProjection(
|
||||
attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=rank, dtype=dtype
|
||||
out_linear_layer, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
|
||||
|
||||
if patch_mlp:
|
||||
for name, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha")
|
||||
fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha")
|
||||
fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha", None)
|
||||
fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha", None)
|
||||
|
||||
current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
|
||||
current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
|
||||
|
||||
fc1_linear_layer = (
|
||||
mlp_module.fc1.regular_linear_layer
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection)
|
||||
else mlp_module.fc1
|
||||
)
|
||||
mlp_module.fc1 = PatchedLoraProjection(
|
||||
mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
|
||||
fc1_linear_layer, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
|
||||
|
||||
fc2_linear_layer = (
|
||||
mlp_module.fc2.regular_linear_layer
|
||||
if isinstance(mlp_module.fc2, PatchedLoraProjection)
|
||||
else mlp_module.fc2
|
||||
)
|
||||
mlp_module.fc2 = PatchedLoraProjection(
|
||||
mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=rank, dtype=dtype
|
||||
fc2_linear_layer, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
|
||||
|
||||
@@ -1676,40 +1800,90 @@ class LoraLoaderMixin:
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
from .models.attention_processor import (
|
||||
LORA_ATTENTION_PROCESSORS,
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnAddedKVProcessor,
|
||||
LoRAAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
|
||||
unet_attention_classes = {type(processor) for _, processor in self.unet.attn_processors.items()}
|
||||
|
||||
if unet_attention_classes.issubset(LORA_ATTENTION_PROCESSORS):
|
||||
# Handle attention processors that are a mix of regular attention and AddedKV
|
||||
# attention.
|
||||
if len(unet_attention_classes) > 1 or LoRAAttnAddedKVProcessor in unet_attention_classes:
|
||||
self.unet.set_default_attn_processor()
|
||||
else:
|
||||
regular_attention_classes = {
|
||||
LoRAAttnProcessor: AttnProcessor,
|
||||
LoRAAttnProcessor2_0: AttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor: XFormersAttnProcessor,
|
||||
}
|
||||
[attention_proc_class] = unet_attention_classes
|
||||
self.unet.set_attn_processor(regular_attention_classes[attention_proc_class]())
|
||||
|
||||
for _, module in self.unet.named_modules():
|
||||
if hasattr(module, "set_lora_layer"):
|
||||
module.set_lora_layer(None)
|
||||
for _, module in self.unet.named_modules():
|
||||
if hasattr(module, "set_lora_layer"):
|
||||
module.set_lora_layer(None)
|
||||
|
||||
# Safe to call the following regardless of LoRA.
|
||||
self._remove_text_encoder_monkey_patch()
|
||||
|
||||
def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
fuse_unet (`bool`, defaults to `True`): Whether to fuse the UNet LoRA parameters.
|
||||
fuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
if fuse_unet:
|
||||
self.unet.fuse_lora()
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._fuse_lora()
|
||||
attn_module.k_proj._fuse_lora()
|
||||
attn_module.v_proj._fuse_lora()
|
||||
attn_module.out_proj._fuse_lora()
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._fuse_lora()
|
||||
mlp_module.fc2._fuse_lora()
|
||||
|
||||
if fuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
fuse_text_encoder_lora(self.text_encoder)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
fuse_text_encoder_lora(self.text_encoder_2)
|
||||
|
||||
def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
if unfuse_unet:
|
||||
self.unet.unfuse_lora()
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._unfuse_lora()
|
||||
attn_module.k_proj._unfuse_lora()
|
||||
attn_module.v_proj._unfuse_lora()
|
||||
attn_module.out_proj._unfuse_lora()
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._unfuse_lora()
|
||||
mlp_module.fc2._unfuse_lora()
|
||||
|
||||
if unfuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
unfuse_text_encoder_lora(self.text_encoder)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
unfuse_text_encoder_lora(self.text_encoder_2)
|
||||
|
||||
|
||||
class FromSingleFileMixin:
|
||||
"""
|
||||
@@ -1790,6 +1964,9 @@ class FromSingleFileMixin:
|
||||
tokenizer ([`~transformers.CLIPTokenizer`], *optional*, defaults to `None`):
|
||||
An instance of `CLIPTokenizer` to use. If this parameter is `None`, the function loads a new instance
|
||||
of `CLIPTokenizer` by itself if needed.
|
||||
original_config_file (`str`):
|
||||
Path to `.yaml` config file corresponding to the original architecture. If `None`, will be
|
||||
automatically inferred by looking for a key that only exists in SD2.0 models.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
@@ -1820,6 +1997,7 @@ class FromSingleFileMixin:
|
||||
# import here to avoid circular dependency
|
||||
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
|
||||
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
@@ -1936,6 +2114,7 @@ class FromSingleFileMixin:
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
tokenizer=tokenizer,
|
||||
original_config_file=original_config_file,
|
||||
)
|
||||
|
||||
if torch_dtype is not None:
|
||||
|
||||
@@ -11,17 +11,21 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
import os
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import logging
|
||||
from .modeling_utils import ModelMixin
|
||||
from .resnet import Downsample2D
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class MultiAdapter(ModelMixin):
|
||||
r"""
|
||||
MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
|
||||
@@ -41,6 +45,31 @@ class MultiAdapter(ModelMixin):
|
||||
self.num_adapter = len(adapters)
|
||||
self.adapters = nn.ModuleList(adapters)
|
||||
|
||||
if len(adapters) == 0:
|
||||
raise ValueError("Expecting at least one adapter")
|
||||
|
||||
if len(adapters) == 1:
|
||||
raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")
|
||||
|
||||
# The outputs from each adapter are added together with a weight
|
||||
# This means that the change in dimenstions from downsampling must
|
||||
# be the same for all adapters. Inductively, it also means the total
|
||||
# downscale factor must also be the same for all adapters.
|
||||
|
||||
first_adapter_total_downscale_factor = adapters[0].total_downscale_factor
|
||||
|
||||
for idx in range(1, len(adapters)):
|
||||
adapter_idx_total_downscale_factor = adapters[idx].total_downscale_factor
|
||||
|
||||
if adapter_idx_total_downscale_factor != first_adapter_total_downscale_factor:
|
||||
raise ValueError(
|
||||
f"Expecting all adapters to have the same total_downscale_factor, "
|
||||
f"but got adapters[0].total_downscale_factor={first_adapter_total_downscale_factor} and "
|
||||
f"adapter[`{idx}`]={adapter_idx_total_downscale_factor}"
|
||||
)
|
||||
|
||||
self.total_downscale_factor = adapters[0].total_downscale_factor
|
||||
|
||||
def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
@@ -56,14 +85,8 @@ class MultiAdapter(ModelMixin):
|
||||
else:
|
||||
adapter_weights = torch.tensor(adapter_weights)
|
||||
|
||||
if xs.shape[1] % self.num_adapter != 0:
|
||||
raise ValueError(
|
||||
f"Expecting multi-adapter's input have number of channel that cab be evenly divisible "
|
||||
f"by num_adapter: {xs.shape[1]} % {self.num_adapter} != 0"
|
||||
)
|
||||
x_list = torch.chunk(xs, self.num_adapter, dim=1)
|
||||
accume_state = None
|
||||
for x, w, adapter in zip(x_list, adapter_weights, self.adapters):
|
||||
for x, w, adapter in zip(xs, adapter_weights, self.adapters):
|
||||
features = adapter(x)
|
||||
if accume_state is None:
|
||||
accume_state = features
|
||||
@@ -72,6 +95,119 @@ class MultiAdapter(ModelMixin):
|
||||
accume_state[i] += w * features[i]
|
||||
return accume_state
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
is_main_process: bool = True,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
variant: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||
`[`~models.adapter.MultiAdapter.from_pretrained`]` class method.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to which to save. Will be created if it doesn't exist.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process or not. Useful when in distributed training like
|
||||
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
|
||||
the main process to avoid race conditions.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
||||
need to replace `torch.save` by another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
||||
variant (`str`, *optional*):
|
||||
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
||||
"""
|
||||
idx = 0
|
||||
model_path_to_save = save_directory
|
||||
for adapter in self.adapters:
|
||||
adapter.save_pretrained(
|
||||
model_path_to_save,
|
||||
is_main_process=is_main_process,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
variant=variant,
|
||||
)
|
||||
|
||||
idx += 1
|
||||
model_path_to_save = model_path_to_save + f"_{idx}"
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained MultiAdapter model from multiple pre-trained adapter models.
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
|
||||
the model, you should first set it back in training mode with `model.train()`.
|
||||
|
||||
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
||||
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
||||
task.
|
||||
|
||||
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
||||
weights are discarded.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_path (`os.PathLike`):
|
||||
A path to a *directory* containing model weights saved using
|
||||
[`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
||||
will be automatically derived from the model's weights.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
|
||||
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
||||
more information about each option see [designing a device
|
||||
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||
max_memory (`Dict`, *optional*):
|
||||
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
|
||||
GPU and the available CPU RAM if unset.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
||||
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
||||
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
||||
setting this argument to `True` will raise an error.
|
||||
variant (`str`, *optional*):
|
||||
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
|
||||
ignored when using `from_flax`.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
|
||||
`safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
|
||||
`safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
|
||||
"""
|
||||
idx = 0
|
||||
adapters = []
|
||||
|
||||
# load adapter and append to list until no adapter directory exists anymore
|
||||
# first adapter has to be saved under `./mydirectory/adapter` to be compliant with `DiffusionPipeline.from_pretrained`
|
||||
# second, third, ... adapters have to be saved under `./mydirectory/adapter_1`, `./mydirectory/adapter_2`, ...
|
||||
model_path_to_load = pretrained_model_path
|
||||
while os.path.isdir(model_path_to_load):
|
||||
adapter = T2IAdapter.from_pretrained(model_path_to_load, **kwargs)
|
||||
adapters.append(adapter)
|
||||
|
||||
idx += 1
|
||||
model_path_to_load = pretrained_model_path + f"_{idx}"
|
||||
|
||||
logger.info(f"{len(adapters)} adapters loaded from {pretrained_model_path}.")
|
||||
|
||||
if len(adapters) == 0:
|
||||
raise ValueError(
|
||||
f"No T2IAdapters found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
|
||||
)
|
||||
|
||||
return cls(adapters)
|
||||
|
||||
|
||||
class T2IAdapter(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
@@ -109,6 +245,8 @@ class T2IAdapter(ModelMixin, ConfigMixin):
|
||||
|
||||
if adapter_type == "full_adapter":
|
||||
self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor)
|
||||
elif adapter_type == "full_adapter_xl":
|
||||
self.adapter = FullAdapterXL(in_channels, channels, num_res_blocks, downscale_factor)
|
||||
elif adapter_type == "light_adapter":
|
||||
self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
|
||||
else:
|
||||
@@ -165,6 +303,48 @@ class FullAdapter(nn.Module):
|
||||
return features
|
||||
|
||||
|
||||
class FullAdapterXL(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
channels: List[int] = [320, 640, 1280, 1280],
|
||||
num_res_blocks: int = 2,
|
||||
downscale_factor: int = 16,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
in_channels = in_channels * downscale_factor**2
|
||||
|
||||
self.unshuffle = nn.PixelUnshuffle(downscale_factor)
|
||||
self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
|
||||
|
||||
self.body = []
|
||||
# blocks to extract XL features with dimensions of [320, 64, 64], [640, 64, 64], [1280, 32, 32], [1280, 32, 32]
|
||||
for i in range(len(channels)):
|
||||
if i == 1:
|
||||
self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks))
|
||||
elif i == 2:
|
||||
self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True))
|
||||
else:
|
||||
self.body.append(AdapterBlock(channels[i], channels[i], num_res_blocks))
|
||||
|
||||
self.body = nn.ModuleList(self.body)
|
||||
# XL has one fewer downsampling
|
||||
self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 2)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
x = self.unshuffle(x)
|
||||
x = self.conv_in(x)
|
||||
|
||||
features = []
|
||||
|
||||
for block in self.body:
|
||||
x = block(x)
|
||||
features.append(x)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
class AdapterBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
|
||||
super().__init__()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -175,8 +175,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "set_processor"):
|
||||
processors[f"{name}.processor"] = module.processor
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
@@ -137,6 +137,15 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
self.latent_shift = latent_shift
|
||||
self.scaling_factor = scaling_factor
|
||||
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
# only relevant if vae tiling is enabled
|
||||
self.spatial_scale_factor = 2**out_channels
|
||||
self.tile_overlap_factor = 0.125
|
||||
self.tile_sample_min_size = 512
|
||||
self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (EncoderTiny, DecoderTiny)):
|
||||
module.gradient_checkpointing = value
|
||||
@@ -149,11 +158,147 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
"""[0, 1] -> raw latents"""
|
||||
return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
|
||||
|
||||
def enable_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def enable_tiling(self, use_tiling: bool = True):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.use_tiling = use_tiling
|
||||
|
||||
def disable_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.enable_tiling(False)
|
||||
|
||||
def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||
steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
|
||||
tiles overlap and are blended together to form a smooth output.
|
||||
|
||||
Args:
|
||||
x (`torch.FloatTensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.autoencoder_tiny.AutoencoderTinyOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] is returned, otherwise a
|
||||
plain `tuple` is returned.
|
||||
"""
|
||||
# scale of encoder output relative to input
|
||||
sf = self.spatial_scale_factor
|
||||
tile_size = self.tile_sample_min_size
|
||||
|
||||
# number of pixels to blend and to traverse between tile
|
||||
blend_size = int(tile_size * self.tile_overlap_factor)
|
||||
traverse_size = tile_size - blend_size
|
||||
|
||||
# tiles index (up/left)
|
||||
ti = range(0, x.shape[-2], traverse_size)
|
||||
tj = range(0, x.shape[-1], traverse_size)
|
||||
|
||||
# mask for blending
|
||||
blend_masks = torch.stack(
|
||||
torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
|
||||
)
|
||||
blend_masks = blend_masks.clamp(0, 1).to(x.device)
|
||||
|
||||
# output array
|
||||
out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
|
||||
for i in ti:
|
||||
for j in tj:
|
||||
tile_in = x[..., i : i + tile_size, j : j + tile_size]
|
||||
# tile result
|
||||
tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
|
||||
tile = self.encoder(tile_in)
|
||||
h, w = tile.shape[-2], tile.shape[-1]
|
||||
# blend tile result into output
|
||||
blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
|
||||
blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
|
||||
blend_mask = blend_mask_i * blend_mask_j
|
||||
tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
|
||||
tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
|
||||
return out
|
||||
|
||||
def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||
steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
|
||||
tiles overlap and are blended together to form a smooth output.
|
||||
|
||||
Args:
|
||||
x (`torch.FloatTensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
# scale of decoder output relative to input
|
||||
sf = self.spatial_scale_factor
|
||||
tile_size = self.tile_latent_min_size
|
||||
|
||||
# number of pixels to blend and to traverse between tiles
|
||||
blend_size = int(tile_size * self.tile_overlap_factor)
|
||||
traverse_size = tile_size - blend_size
|
||||
|
||||
# tiles index (up/left)
|
||||
ti = range(0, x.shape[-2], traverse_size)
|
||||
tj = range(0, x.shape[-1], traverse_size)
|
||||
|
||||
# mask for blending
|
||||
blend_masks = torch.stack(
|
||||
torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
|
||||
)
|
||||
blend_masks = blend_masks.clamp(0, 1).to(x.device)
|
||||
|
||||
# output array
|
||||
out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
|
||||
for i in ti:
|
||||
for j in tj:
|
||||
tile_in = x[..., i : i + tile_size, j : j + tile_size]
|
||||
# tile result
|
||||
tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
|
||||
tile = self.decoder(tile_in)
|
||||
h, w = tile.shape[-2], tile.shape[-1]
|
||||
# blend tile result into output
|
||||
blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
|
||||
blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
|
||||
blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
|
||||
tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
|
||||
return out
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.FloatTensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
|
||||
output = self.encoder(x)
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
output = [self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x) for x_slice in x.split(1)]
|
||||
output = torch.cat(output)
|
||||
else:
|
||||
output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
@@ -162,10 +307,11 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
|
||||
output = self.decoder(x)
|
||||
# Refer to the following discussion to know why this is needed.
|
||||
# https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
|
||||
output = output.mul_(2).sub_(1)
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
|
||||
output = torch.cat(output)
|
||||
else:
|
||||
output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
@@ -184,8 +330,15 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
enc = self.encode(sample).latents
|
||||
|
||||
# scale latents to be in [0, 1], then quantize latents to a byte tensor,
|
||||
# as if we were storing the latents in an RGBA uint8 image.
|
||||
scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
|
||||
unscaled_enc = self.unscale_latents(scaled_enc)
|
||||
|
||||
# unquantize latents back into [0, 1], then unscale latents back to their original range,
|
||||
# as if we were loading the latents from an RGBA uint8 image.
|
||||
unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
|
||||
|
||||
dec = self.decode(unscaled_enc)
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -497,8 +497,8 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "set_processor"):
|
||||
processors[f"{name}.processor"] = module.processor
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
@@ -723,7 +723,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
if "addition_embed_type" in self.config:
|
||||
if self.config.addition_embed_type is not None:
|
||||
if self.config.addition_embed_type == "text":
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
|
||||
|
||||
@@ -14,9 +14,15 @@
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class LoRALinearLayer(nn.Module):
|
||||
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
|
||||
@@ -28,6 +34,8 @@ class LoRALinearLayer(nn.Module):
|
||||
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
||||
self.network_alpha = network_alpha
|
||||
self.rank = rank
|
||||
self.out_features = out_features
|
||||
self.in_features = in_features
|
||||
|
||||
nn.init.normal_(self.down.weight, std=1 / rank)
|
||||
nn.init.zeros_(self.up.weight)
|
||||
@@ -89,6 +97,51 @@ class LoRACompatibleConv(nn.Conv2d):
|
||||
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def _fuse_lora(self):
|
||||
if self.lora_layer is None:
|
||||
return
|
||||
|
||||
dtype, device = self.weight.data.dtype, self.weight.data.device
|
||||
logger.info(f"Fusing LoRA weights for {self.__class__}")
|
||||
|
||||
w_orig = self.weight.data.float()
|
||||
w_up = self.lora_layer.up.weight.data.float()
|
||||
w_down = self.lora_layer.down.weight.data.float()
|
||||
|
||||
if self.lora_layer.network_alpha is not None:
|
||||
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
|
||||
|
||||
fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
|
||||
fusion = fusion.reshape((w_orig.shape))
|
||||
fused_weight = w_orig + fusion
|
||||
self.weight.data = fused_weight.to(device=device, dtype=dtype)
|
||||
|
||||
# we can drop the lora layer now
|
||||
self.lora_layer = None
|
||||
|
||||
# offload the up and down matrices to CPU to not blow the memory
|
||||
self.w_up = w_up.cpu()
|
||||
self.w_down = w_down.cpu()
|
||||
|
||||
def _unfuse_lora(self):
|
||||
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
|
||||
return
|
||||
logger.info(f"Unfusing LoRA weights for {self.__class__}")
|
||||
|
||||
fused_weight = self.weight.data
|
||||
dtype, device = fused_weight.data.dtype, fused_weight.data.device
|
||||
|
||||
self.w_up = self.w_up.to(device=device, dtype=dtype)
|
||||
self.w_down = self.w_down.to(device, dtype=dtype)
|
||||
|
||||
fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1))
|
||||
fusion = fusion.reshape((fused_weight.shape))
|
||||
unfused_weight = fused_weight - fusion
|
||||
self.weight.data = unfused_weight.to(device=device, dtype=dtype)
|
||||
|
||||
self.w_up = None
|
||||
self.w_down = None
|
||||
|
||||
def forward(self, x):
|
||||
if self.lora_layer is None:
|
||||
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break
|
||||
@@ -107,11 +160,50 @@ class LoRACompatibleLinear(nn.Linear):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
|
||||
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def forward(self, x):
|
||||
def _fuse_lora(self):
|
||||
if self.lora_layer is None:
|
||||
return super().forward(x)
|
||||
return
|
||||
|
||||
dtype, device = self.weight.data.dtype, self.weight.data.device
|
||||
|
||||
w_orig = self.weight.data.float()
|
||||
w_up = self.lora_layer.up.weight.data.float()
|
||||
w_down = self.lora_layer.down.weight.data.float()
|
||||
|
||||
if self.lora_layer.network_alpha is not None:
|
||||
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
|
||||
|
||||
fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]
|
||||
self.weight.data = fused_weight.to(device=device, dtype=dtype)
|
||||
|
||||
# we can drop the lora layer now
|
||||
self.lora_layer = None
|
||||
|
||||
# offload the up and down matrices to CPU to not blow the memory
|
||||
self.w_up = w_up.cpu()
|
||||
self.w_down = w_down.cpu()
|
||||
|
||||
def _unfuse_lora(self):
|
||||
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
|
||||
return
|
||||
|
||||
fused_weight = self.weight.data
|
||||
dtype, device = fused_weight.dtype, fused_weight.device
|
||||
|
||||
w_up = self.w_up.to(device=device).float()
|
||||
w_down = self.w_down.to(device).float()
|
||||
|
||||
unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
|
||||
self.weight.data = unfused_weight.to(device=device, dtype=dtype)
|
||||
|
||||
self.w_up = None
|
||||
self.w_down = None
|
||||
|
||||
def forward(self, hidden_states, lora_scale: int = 1):
|
||||
if self.lora_layer is None:
|
||||
return super().forward(hidden_states)
|
||||
else:
|
||||
return super().forward(x) + self.lora_layer(x)
|
||||
return super().forward(hidden_states) + lora_scale * self.lora_layer(hidden_states)
|
||||
|
||||
@@ -171,8 +171,8 @@ class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "set_processor"):
|
||||
processors[f"{name}.processor"] = module.processor
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
@@ -88,6 +88,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
use_linear_projection: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "layer_norm",
|
||||
norm_elementwise_affine: bool = True,
|
||||
@@ -181,6 +182,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
attention_bias=attention_bias,
|
||||
only_cross_attention=only_cross_attention,
|
||||
double_self_attention=double_self_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_type=norm_type,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
|
||||
@@ -584,8 +584,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "set_processor"):
|
||||
processors[f"{name}.processor"] = module.processor
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
@@ -965,6 +965,13 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
# To support T2I-Adapter-XL
|
||||
if (
|
||||
is_adapter
|
||||
and len(down_block_additional_residuals) > 0
|
||||
and sample.shape == down_block_additional_residuals[0].shape
|
||||
):
|
||||
sample += down_block_additional_residuals.pop(0)
|
||||
|
||||
if is_controlnet:
|
||||
sample = sample + mid_block_additional_residual
|
||||
|
||||
@@ -280,8 +280,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "set_processor"):
|
||||
processors[f"{name}.processor"] = module.processor
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
@@ -732,7 +732,8 @@ class EncoderTiny(nn.Module):
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
|
||||
|
||||
else:
|
||||
x = self.layers(x)
|
||||
# scale image from [-1, 1] to [0, 1] to match TAESD convention
|
||||
x = self.layers(x.add(1).div(2))
|
||||
|
||||
return x
|
||||
|
||||
@@ -790,4 +791,5 @@ class DecoderTiny(nn.Module):
|
||||
else:
|
||||
x = self.layers(x)
|
||||
|
||||
return x
|
||||
# scale image from [0, 1] to [-1, 1] to match diffusers convention
|
||||
return x.mul(2).sub(1)
|
||||
|
||||
@@ -46,10 +46,12 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
|
||||
from .audioldm import AudioLDMPipeline
|
||||
from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
from .controlnet import (
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
)
|
||||
from .deepfloyd_if import (
|
||||
@@ -82,6 +84,7 @@ else:
|
||||
KandinskyV22PriorPipeline,
|
||||
)
|
||||
from .latent_diffusion import LDMTextToImagePipeline
|
||||
from .musicldm import MusicLDMPipeline
|
||||
from .paint_by_example import PaintByExamplePipeline
|
||||
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
@@ -115,7 +118,7 @@ else:
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
from .t2i_adapter import StableDiffusionAdapterPipeline
|
||||
from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline
|
||||
from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline
|
||||
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
|
||||
from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder
|
||||
|
||||
@@ -258,12 +258,45 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = (
|
||||
"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()`"
|
||||
" instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
)
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -402,12 +435,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
@@ -634,7 +662,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -644,6 +672,11 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
@@ -25,7 +25,7 @@ from transformers import CLIPImageProcessor, XLMRobertaTokenizer
|
||||
from diffusers.utils import is_accelerate_available, is_accelerate_version
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
@@ -259,12 +259,45 @@ class AltDiffusionImg2ImgPipeline(
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = (
|
||||
"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()`"
|
||||
" instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
)
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -403,12 +436,7 @@ class AltDiffusionImg2ImgPipeline(
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
@@ -567,14 +595,7 @@ class AltDiffusionImg2ImgPipeline(
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
image: PipelineImageInput = None,
|
||||
strength: float = 0.8,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 7.5,
|
||||
@@ -597,7 +618,10 @@ class AltDiffusionImg2ImgPipeline(
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image` or tensor representing an image batch to be used as the starting point. Can also accept image
|
||||
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
||||
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
||||
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
||||
latents as `image`, but if passing latents directly it is not encoded again.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
||||
@@ -672,7 +696,7 @@ class AltDiffusionImg2ImgPipeline(
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -682,6 +706,11 @@ class AltDiffusionImg2ImgPipeline(
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Preprocess image
|
||||
image = self.image_processor.preprocess(image)
|
||||
|
||||
@@ -418,8 +418,7 @@ class AudioLDMPipeline(DiffusionPipeline):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
@@ -436,9 +435,9 @@ class AudioLDMPipeline(DiffusionPipeline):
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated audio.
|
||||
[`~pipelines.AudioPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is a list with the generated audio.
|
||||
"""
|
||||
# 0. Convert audio input length from seconds to spectrogram height
|
||||
vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import (
|
||||
AudioLDM2Pipeline,
|
||||
AudioLDM2ProjectionModel,
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
)
|
||||
else:
|
||||
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
from .pipeline_audioldm2 import AudioLDM2Pipeline
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,977 @@
|
||||
# Copyright 2023 CVSSP, ByteDance and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import (
|
||||
ClapFeatureExtractor,
|
||||
ClapModel,
|
||||
GPT2Model,
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
SpeechT5HifiGan,
|
||||
T5EncoderModel,
|
||||
T5Tokenizer,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from ...models import AutoencoderKL
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_librosa_available,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
||||
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
|
||||
|
||||
if is_librosa_available():
|
||||
import librosa
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import scipy
|
||||
>>> import torch
|
||||
>>> from diffusers import AudioLDM2Pipeline
|
||||
|
||||
>>> repo_id = "cvssp/audioldm2"
|
||||
>>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> # define the prompts
|
||||
>>> prompt = "The sound of a hammer hitting a wooden surface."
|
||||
>>> negative_prompt = "Low quality."
|
||||
|
||||
>>> # set the seed for generator
|
||||
>>> generator = torch.Generator("cuda").manual_seed(0)
|
||||
|
||||
>>> # run the generation
|
||||
>>> audio = pipe(
|
||||
... prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... num_inference_steps=200,
|
||||
... audio_length_in_s=10.0,
|
||||
... num_waveforms_per_prompt=3,
|
||||
... generator=generator,
|
||||
... ).audios
|
||||
|
||||
>>> # save the best audio sample (index 0) as a .wav file
|
||||
>>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
inputs_embeds,
|
||||
attention_mask=None,
|
||||
past_key_values=None,
|
||||
**kwargs,
|
||||
):
|
||||
if past_key_values is not None:
|
||||
# only last token for inputs_embeds if past is defined in kwargs
|
||||
inputs_embeds = inputs_embeds[:, -1:]
|
||||
|
||||
return {
|
||||
"inputs_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
}
|
||||
|
||||
|
||||
class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-audio generation using AudioLDM2.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`~transformers.ClapModel`]):
|
||||
First frozen text-encoder. AudioLDM2 uses the joint audio-text embedding model
|
||||
[CLAP](https://huggingface.co/docs/transformers/model_doc/clap#transformers.CLAPTextModelWithProjection),
|
||||
specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. The
|
||||
text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used to
|
||||
rank generated waveforms against the text prompt by computing similarity scores.
|
||||
text_encoder_2 ([`~transformers.T5EncoderModel`]):
|
||||
Second frozen text-encoder. AudioLDM2 uses the encoder of
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
||||
[google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant.
|
||||
projection_model ([`AudioLDM2ProjectionModel`]):
|
||||
A trained model used to linearly project the hidden-states from the first and second text encoder models
|
||||
and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders are
|
||||
concatenated to give the input to the language model.
|
||||
language_model ([`~transformers.GPT2Model`]):
|
||||
An auto-regressive language model used to generate a sequence of hidden-states conditioned on the projected
|
||||
outputs from the two text encoders.
|
||||
tokenizer ([`~transformers.RobertaTokenizer`]):
|
||||
Tokenizer to tokenize text for the first frozen text-encoder.
|
||||
tokenizer_2 ([`~transformers.T5Tokenizer`]):
|
||||
Tokenizer to tokenize text for the second frozen text-encoder.
|
||||
feature_extractor ([`~transformers.ClapFeatureExtractor`]):
|
||||
Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
A `UNet2DConditionModel` to denoise the encoded audio latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
vocoder ([`~transformers.SpeechT5HifiGan`]):
|
||||
Vocoder of class `SpeechT5HifiGan` to convert the mel-spectrogram latents to the final audio waveform.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: ClapModel,
|
||||
text_encoder_2: T5EncoderModel,
|
||||
projection_model: AudioLDM2ProjectionModel,
|
||||
language_model: GPT2Model,
|
||||
tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
|
||||
tokenizer_2: Union[T5Tokenizer, T5TokenizerFast],
|
||||
feature_extractor: ClapFeatureExtractor,
|
||||
unet: AudioLDM2UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
vocoder: SpeechT5HifiGan,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
projection_model=projection_model,
|
||||
language_model=language_model,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
feature_extractor=feature_extractor,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vocoder=vocoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
model_sequence = [
|
||||
self.text_encoder.text_model,
|
||||
self.text_encoder.text_projection,
|
||||
self.text_encoder_2,
|
||||
self.projection_model,
|
||||
self.language_model,
|
||||
self.unet,
|
||||
self.vae,
|
||||
self.vocoder,
|
||||
self.text_encoder,
|
||||
]
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in model_sequence:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
def generate_language_model(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor = None,
|
||||
max_new_tokens: int = 8,
|
||||
**model_kwargs,
|
||||
):
|
||||
"""
|
||||
|
||||
Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs.
|
||||
|
||||
Parameters:
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
The sequence used as a prompt for the generation.
|
||||
max_new_tokens (`int`):
|
||||
Number of new tokens to generate.
|
||||
model_kwargs (`Dict[str, Any]`, *optional*):
|
||||
Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the `forward`
|
||||
function of the model.
|
||||
|
||||
Return:
|
||||
`inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
The sequence of generated hidden-states.
|
||||
"""
|
||||
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
|
||||
for _ in range(max_new_tokens):
|
||||
# prepare model inputs
|
||||
model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
|
||||
|
||||
# forward pass to get next hidden states
|
||||
output = self.language_model(**model_inputs, return_dict=True)
|
||||
|
||||
next_hidden_states = output.last_hidden_state
|
||||
|
||||
# Update the model input
|
||||
inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
|
||||
|
||||
# Update generated hidden states, model inputs, and length for next step
|
||||
model_kwargs = self.language_model._update_model_kwargs_for_generation(output, model_kwargs)
|
||||
|
||||
return inputs_embeds[:, -max_new_tokens:, :]
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_waveforms_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
generated_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_generated_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
negative_attention_mask: Optional[torch.LongTensor] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device (`torch.device`):
|
||||
torch device
|
||||
num_waveforms_per_prompt (`int`):
|
||||
number of waveforms that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the audio generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-computed text embeddings from the Flan T5 model. Can be used to easily tweak text inputs, *e.g.*
|
||||
prompt weighting. If not provided, text embeddings will be computed from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-computed negative text embeddings from the Flan T5 model. Can be used to easily tweak text inputs,
|
||||
*e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
|
||||
`negative_prompt` input argument.
|
||||
generated_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings from the GPT2 langauge model. Can be used to easily tweak text inputs,
|
||||
*e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
|
||||
argument.
|
||||
negative_generated_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
|
||||
inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
|
||||
`negative_prompt` input argument.
|
||||
attention_mask (`torch.LongTensor`, *optional*):
|
||||
Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
|
||||
be computed from `prompt` input argument.
|
||||
negative_attention_mask (`torch.LongTensor`, *optional*):
|
||||
Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
|
||||
mask will be computed from `negative_prompt` input argument.
|
||||
max_new_tokens (`int`, *optional*, defaults to None):
|
||||
The number of new tokens to generate with the GPT2 language model.
|
||||
Returns:
|
||||
prompt_embeds (`torch.FloatTensor`):
|
||||
Text embeddings from the Flan T5 model.
|
||||
attention_mask (`torch.LongTensor`):
|
||||
Attention mask to be applied to the `prompt_embeds`.
|
||||
generated_prompt_embeds (`torch.FloatTensor`):
|
||||
Text embeddings generated from the GPT2 langauge model.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> import scipy
|
||||
>>> import torch
|
||||
>>> from diffusers import AudioLDM2Pipeline
|
||||
|
||||
>>> repo_id = "cvssp/audioldm2"
|
||||
>>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> # Get text embedding vectors
|
||||
>>> prompt_embeds, attention_mask, generated_prompt_embeds = pipe.encode_prompt(
|
||||
... prompt="Techno music with a strong, upbeat tempo and high melodic riffs",
|
||||
... device="cuda",
|
||||
... do_classifier_free_guidance=True,
|
||||
... )
|
||||
|
||||
>>> # Pass text embeddings to pipeline for text-conditional audio generation
|
||||
>>> audio = pipe(
|
||||
... prompt_embeds=prompt_embeds,
|
||||
... attention_mask=attention_mask,
|
||||
... generated_prompt_embeds=generated_prompt_embeds,
|
||||
... num_inference_steps=200,
|
||||
... audio_length_in_s=10.0,
|
||||
... ).audios[0]
|
||||
|
||||
>>> # save generated audio sample
|
||||
>>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
|
||||
```"""
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# Define tokenizers and text encoders
|
||||
tokenizers = [self.tokenizer, self.tokenizer_2]
|
||||
text_encoders = [self.text_encoder, self.text_encoder_2]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds_list = []
|
||||
attention_mask_list = []
|
||||
|
||||
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length" if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True,
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask
|
||||
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
f"The following part of your input was truncated because {text_encoder.config.model_type} can "
|
||||
f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
text_input_ids = text_input_ids.to(device)
|
||||
attention_mask = attention_mask.to(device)
|
||||
|
||||
if text_encoder.config.model_type == "clap":
|
||||
prompt_embeds = text_encoder.get_text_features(
|
||||
text_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
# append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
|
||||
prompt_embeds = prompt_embeds[:, None, :]
|
||||
# make sure that we attend to this single hidden-state
|
||||
attention_mask = attention_mask.new_ones((batch_size, 1))
|
||||
else:
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
attention_mask_list.append(attention_mask)
|
||||
|
||||
projection_output = self.projection_model(
|
||||
hidden_states=prompt_embeds_list[0],
|
||||
hidden_states_1=prompt_embeds_list[1],
|
||||
attention_mask=attention_mask_list[0],
|
||||
attention_mask_1=attention_mask_list[1],
|
||||
)
|
||||
projected_prompt_embeds = projection_output.hidden_states
|
||||
projected_attention_mask = projection_output.attention_mask
|
||||
|
||||
generated_prompt_embeds = self.generate_language_model(
|
||||
projected_prompt_embeds,
|
||||
attention_mask=projected_attention_mask,
|
||||
max_new_tokens=max_new_tokens,
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
attention_mask = (
|
||||
attention_mask.to(device=device)
|
||||
if attention_mask is not None
|
||||
else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=device)
|
||||
)
|
||||
generated_prompt_embeds = generated_prompt_embeds.to(dtype=self.language_model.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, hidden_size = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len, hidden_size)
|
||||
|
||||
# duplicate attention mask for each generation per prompt
|
||||
attention_mask = attention_mask.repeat(1, num_waveforms_per_prompt)
|
||||
attention_mask = attention_mask.view(bs_embed * num_waveforms_per_prompt, seq_len)
|
||||
|
||||
bs_embed, seq_len, hidden_size = generated_prompt_embeds.shape
|
||||
# duplicate generated embeddings for each generation per prompt, using mps friendly method
|
||||
generated_prompt_embeds = generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
|
||||
generated_prompt_embeds = generated_prompt_embeds.view(
|
||||
bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
|
||||
)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
negative_prompt_embeds_list = []
|
||||
negative_attention_mask_list = []
|
||||
max_length = prompt_embeds.shape[1]
|
||||
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
||||
uncond_input = tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length
|
||||
if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
|
||||
else max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
uncond_input_ids = uncond_input.input_ids.to(device)
|
||||
negative_attention_mask = uncond_input.attention_mask.to(device)
|
||||
|
||||
if text_encoder.config.model_type == "clap":
|
||||
negative_prompt_embeds = text_encoder.get_text_features(
|
||||
uncond_input_ids,
|
||||
attention_mask=negative_attention_mask,
|
||||
)
|
||||
# append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
|
||||
negative_prompt_embeds = negative_prompt_embeds[:, None, :]
|
||||
# make sure that we attend to this single hidden-state
|
||||
negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1))
|
||||
else:
|
||||
negative_prompt_embeds = text_encoder(
|
||||
uncond_input_ids,
|
||||
attention_mask=negative_attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||
negative_attention_mask_list.append(negative_attention_mask)
|
||||
|
||||
projection_output = self.projection_model(
|
||||
hidden_states=negative_prompt_embeds_list[0],
|
||||
hidden_states_1=negative_prompt_embeds_list[1],
|
||||
attention_mask=negative_attention_mask_list[0],
|
||||
attention_mask_1=negative_attention_mask_list[1],
|
||||
)
|
||||
negative_projected_prompt_embeds = projection_output.hidden_states
|
||||
negative_projected_attention_mask = projection_output.attention_mask
|
||||
|
||||
negative_generated_prompt_embeds = self.generate_language_model(
|
||||
negative_projected_prompt_embeds,
|
||||
attention_mask=negative_projected_attention_mask,
|
||||
max_new_tokens=max_new_tokens,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
negative_attention_mask = (
|
||||
negative_attention_mask.to(device=device)
|
||||
if negative_attention_mask is not None
|
||||
else torch.ones(negative_prompt_embeds.shape[:2], dtype=torch.long, device=device)
|
||||
)
|
||||
negative_generated_prompt_embeds = negative_generated_prompt_embeds.to(
|
||||
dtype=self.language_model.dtype, device=device
|
||||
)
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len, -1)
|
||||
|
||||
# duplicate unconditional attention mask for each generation per prompt
|
||||
negative_attention_mask = negative_attention_mask.repeat(1, num_waveforms_per_prompt)
|
||||
negative_attention_mask = negative_attention_mask.view(batch_size * num_waveforms_per_prompt, seq_len)
|
||||
|
||||
# duplicate unconditional generated embeddings for each generation per prompt
|
||||
seq_len = negative_generated_prompt_embeds.shape[1]
|
||||
negative_generated_prompt_embeds = negative_generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
|
||||
negative_generated_prompt_embeds = negative_generated_prompt_embeds.view(
|
||||
batch_size * num_waveforms_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
attention_mask = torch.cat([negative_attention_mask, attention_mask])
|
||||
generated_prompt_embeds = torch.cat([negative_generated_prompt_embeds, generated_prompt_embeds])
|
||||
|
||||
return prompt_embeds, attention_mask, generated_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform
|
||||
def mel_spectrogram_to_waveform(self, mel_spectrogram):
|
||||
if mel_spectrogram.dim() == 4:
|
||||
mel_spectrogram = mel_spectrogram.squeeze(1)
|
||||
|
||||
waveform = self.vocoder(mel_spectrogram)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
waveform = waveform.cpu().float()
|
||||
return waveform
|
||||
|
||||
def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype):
|
||||
if not is_librosa_available():
|
||||
logger.info(
|
||||
"Automatic scoring of the generated audio waveforms against the input prompt text requires the "
|
||||
"`librosa` package to resample the generated waveforms. Returning the audios in the order they were "
|
||||
"generated. To enable automatic scoring, install `librosa` with: `pip install librosa`."
|
||||
)
|
||||
return audio
|
||||
inputs = self.tokenizer(text, return_tensors="pt", padding=True)
|
||||
resampled_audio = librosa.resample(
|
||||
audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate
|
||||
)
|
||||
inputs["input_features"] = self.feature_extractor(
|
||||
list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate
|
||||
).input_features.type(dtype)
|
||||
inputs = inputs.to(device)
|
||||
|
||||
# compute the audio-text similarity score using the CLAP model
|
||||
logits_per_text = self.text_encoder(**inputs).logits_per_text
|
||||
# sort by the highest matching generations per prompt
|
||||
indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt]
|
||||
audio = torch.index_select(audio, 0, indices.reshape(-1).cpu())
|
||||
return audio
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
audio_length_in_s,
|
||||
vocoder_upsample_factor,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
generated_prompt_embeds=None,
|
||||
negative_generated_prompt_embeds=None,
|
||||
attention_mask=None,
|
||||
negative_attention_mask=None,
|
||||
):
|
||||
min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
|
||||
if audio_length_in_s < min_audio_length_in_s:
|
||||
raise ValueError(
|
||||
f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
|
||||
f"is {audio_length_in_s}."
|
||||
)
|
||||
|
||||
if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
|
||||
raise ValueError(
|
||||
f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
|
||||
f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
|
||||
f"{self.vae_scale_factor}."
|
||||
)
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and (prompt_embeds is None or generated_prompt_embeds is None):
|
||||
raise ValueError(
|
||||
"Provide either `prompt`, or `prompt_embeds` and `generated_prompt_embeds`. Cannot leave "
|
||||
"`prompt` undefined without specifying both `prompt_embeds` and `generated_prompt_embeds`."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_embeds is not None and negative_generated_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Cannot forward `negative_prompt_embeds` without `negative_generated_prompt_embeds`. Ensure that"
|
||||
"both arguments are specified"
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
|
||||
raise ValueError(
|
||||
"`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
|
||||
f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
|
||||
)
|
||||
|
||||
if generated_prompt_embeds is not None and negative_generated_prompt_embeds is not None:
|
||||
if generated_prompt_embeds.shape != negative_generated_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`generated_prompt_embeds` and `negative_generated_prompt_embeds` must have the same shape when "
|
||||
f"passed directly, but got: `generated_prompt_embeds` {generated_prompt_embeds.shape} != "
|
||||
f"`negative_generated_prompt_embeds` {negative_generated_prompt_embeds.shape}."
|
||||
)
|
||||
if (
|
||||
negative_attention_mask is not None
|
||||
and negative_attention_mask.shape != negative_prompt_embeds.shape[:2]
|
||||
):
|
||||
raise ValueError(
|
||||
"`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
|
||||
f"`attention_mask: {negative_attention_mask.shape} != `prompt_embeds` {negative_prompt_embeds.shape}"
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height // self.vae_scale_factor,
|
||||
self.vocoder.config.model_in_dim // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
audio_length_in_s: Optional[float] = None,
|
||||
num_inference_steps: int = 200,
|
||||
guidance_scale: float = 3.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_waveforms_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
generated_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_generated_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
negative_attention_mask: Optional[torch.LongTensor] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
|
||||
audio_length_in_s (`int`, *optional*, defaults to 10.24):
|
||||
The length of the generated audio sample in seconds.
|
||||
num_inference_steps (`int`, *optional*, defaults to 200):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||
A higher guidance scale value encourages the model to generate audio that is closely linked to the text
|
||||
`prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
|
||||
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
||||
num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, then automatic
|
||||
scoring is performed between the generated outputs and the text prompt. This scoring ranks the
|
||||
generated waveforms based on their cosine similarity with the text input in the joint text-audio
|
||||
embedding space.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
||||
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for spectrogram
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
generated_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings from the GPT2 langauge model. Can be used to easily tweak text inputs,
|
||||
*e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
|
||||
argument.
|
||||
negative_generated_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
|
||||
inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
|
||||
`negative_prompt` input argument.
|
||||
attention_mask (`torch.LongTensor`, *optional*):
|
||||
Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
|
||||
be computed from `prompt` input argument.
|
||||
negative_attention_mask (`torch.LongTensor`, *optional*):
|
||||
Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
|
||||
mask will be computed from `negative_prompt` input argument.
|
||||
max_new_tokens (`int`, *optional*, defaults to None):
|
||||
Number of new tokens to generate with the GPT2 language model. If not provided, number of tokens will
|
||||
be taken from the config of the model.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
output_type (`str`, *optional*, defaults to `"np"`):
|
||||
The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
|
||||
`"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
|
||||
model (LDM) output.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated audio.
|
||||
"""
|
||||
# 0. Convert audio input length from seconds to spectrogram height
|
||||
vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate
|
||||
|
||||
if audio_length_in_s is None:
|
||||
audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor
|
||||
|
||||
height = int(audio_length_in_s / vocoder_upsample_factor)
|
||||
|
||||
original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
|
||||
if height % self.vae_scale_factor != 0:
|
||||
height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
|
||||
logger.info(
|
||||
f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
|
||||
f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
|
||||
f"denoising process."
|
||||
)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
audio_length_in_s,
|
||||
vocoder_upsample_factor,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
generated_prompt_embeds,
|
||||
negative_generated_prompt_embeds,
|
||||
attention_mask,
|
||||
negative_attention_mask,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, attention_mask, generated_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_waveforms_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
generated_prompt_embeds=generated_prompt_embeds,
|
||||
negative_generated_prompt_embeds=negative_generated_prompt_embeds,
|
||||
attention_mask=attention_mask,
|
||||
negative_attention_mask=negative_attention_mask,
|
||||
max_new_tokens=max_new_tokens,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_waveforms_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=generated_prompt_embeds,
|
||||
encoder_hidden_states_1=prompt_embeds,
|
||||
encoder_attention_mask_1=attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 8. Post-processing
|
||||
if not output_type == "latent":
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
mel_spectrogram = self.vae.decode(latents).sample
|
||||
else:
|
||||
return AudioPipelineOutput(audios=latents)
|
||||
|
||||
audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
|
||||
|
||||
audio = audio[:, :original_waveform_length]
|
||||
|
||||
# 9. Automatic scoring
|
||||
if num_waveforms_per_prompt > 1 and prompt is not None:
|
||||
audio = self.score_waveforms(
|
||||
text=prompt,
|
||||
audio=audio,
|
||||
num_waveforms_per_prompt=num_waveforms_per_prompt,
|
||||
device=device,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
|
||||
if output_type == "np":
|
||||
audio = audio.numpy()
|
||||
|
||||
if not return_dict:
|
||||
return (audio,)
|
||||
|
||||
return AudioPipelineOutput(audios=audio)
|
||||
@@ -17,10 +17,12 @@ import inspect
|
||||
from collections import OrderedDict
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..utils import DIFFUSERS_CACHE
|
||||
from .controlnet import (
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
)
|
||||
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
|
||||
@@ -72,6 +74,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("kandinsky", KandinskyImg2ImgCombinedPipeline),
|
||||
("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
|
||||
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
|
||||
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -295,7 +298,29 @@ class AutoPipelineForText2Image(ConfigMixin):
|
||||
>>> image = pipeline(prompt).images[0]
|
||||
```
|
||||
"""
|
||||
config = cls.load_config(pretrained_model_or_path)
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
user_agent = kwargs.pop("user_agent", {})
|
||||
|
||||
load_config_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"resume_download": resume_download,
|
||||
"proxies": proxies,
|
||||
"use_auth_token": use_auth_token,
|
||||
"local_files_only": local_files_only,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
"user_agent": user_agent,
|
||||
}
|
||||
|
||||
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
|
||||
orig_class_name = config["_class_name"]
|
||||
|
||||
if "controlnet" in kwargs:
|
||||
@@ -303,6 +328,7 @@ class AutoPipelineForText2Image(ConfigMixin):
|
||||
|
||||
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
|
||||
|
||||
kwargs = {**load_config_kwargs, **kwargs}
|
||||
return text_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@@ -535,7 +561,29 @@ class AutoPipelineForImage2Image(ConfigMixin):
|
||||
>>> image = pipeline(prompt, image).images[0]
|
||||
```
|
||||
"""
|
||||
config = cls.load_config(pretrained_model_or_path)
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
user_agent = kwargs.pop("user_agent", {})
|
||||
|
||||
load_config_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"resume_download": resume_download,
|
||||
"proxies": proxies,
|
||||
"use_auth_token": use_auth_token,
|
||||
"local_files_only": local_files_only,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
"user_agent": user_agent,
|
||||
}
|
||||
|
||||
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
|
||||
orig_class_name = config["_class_name"]
|
||||
|
||||
if "controlnet" in kwargs:
|
||||
@@ -543,6 +591,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
|
||||
|
||||
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
|
||||
|
||||
kwargs = {**load_config_kwargs, **kwargs}
|
||||
return image_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@@ -776,7 +825,29 @@ class AutoPipelineForInpainting(ConfigMixin):
|
||||
>>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
|
||||
```
|
||||
"""
|
||||
config = cls.load_config(pretrained_model_or_path)
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
user_agent = kwargs.pop("user_agent", {})
|
||||
|
||||
load_config_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"resume_download": resume_download,
|
||||
"proxies": proxies,
|
||||
"use_auth_token": use_auth_token,
|
||||
"local_files_only": local_files_only,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
"user_agent": user_agent,
|
||||
}
|
||||
|
||||
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
|
||||
orig_class_name = config["_class_name"]
|
||||
|
||||
if "controlnet" in kwargs:
|
||||
@@ -784,6 +855,7 @@ class AutoPipelineForInpainting(ConfigMixin):
|
||||
|
||||
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
|
||||
|
||||
kwargs = {**load_config_kwargs, **kwargs}
|
||||
return inpainting_cls.from_pretrained(pretrained_model_or_path, **kwargs)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -17,6 +17,7 @@ else:
|
||||
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
|
||||
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
|
||||
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
|
||||
|
||||
|
||||
if is_transformers_available() and is_flax_available():
|
||||
|
||||
@@ -23,11 +23,12 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_compiled_module,
|
||||
@@ -250,12 +251,43 @@ class StableDiffusionControlNetPipeline(
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -394,12 +426,7 @@ class StableDiffusionControlNetPipeline(
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -678,14 +705,7 @@ class StableDiffusionControlNetPipeline(
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
image: PipelineImageInput = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
@@ -849,7 +869,7 @@ class StableDiffusionControlNetPipeline(
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -859,6 +879,11 @@ class StableDiffusionControlNetPipeline(
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare image
|
||||
if isinstance(controlnet, ControlNetModel):
|
||||
|
||||
@@ -23,7 +23,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
@@ -276,12 +276,43 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -420,12 +451,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -750,22 +776,8 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
control_image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
image: PipelineImageInput = None,
|
||||
control_image: PipelineImageInput = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
strength: float = 0.8,
|
||||
@@ -935,7 +947,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -945,6 +957,12 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare image
|
||||
image = self.image_processor.preprocess(image).to(dtype=torch.float32)
|
||||
|
||||
|
||||
@@ -24,11 +24,12 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_compiled_module,
|
||||
@@ -133,7 +134,12 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
|
||||
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
|
||||
dimensions: ``batch x channels x height x width``.
|
||||
"""
|
||||
|
||||
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
|
||||
deprecate(
|
||||
"prepare_mask_and_masked_image",
|
||||
"0.30.0",
|
||||
deprecation_message,
|
||||
)
|
||||
if image is None:
|
||||
raise ValueError("`image` input cannot be undefined.")
|
||||
|
||||
@@ -316,6 +322,9 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
||||
)
|
||||
self.control_image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
||||
)
|
||||
@@ -393,12 +402,43 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -537,12 +577,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -615,7 +650,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=1.0,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
if height is not None and height % 8 != 0 or width is not None and width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
@@ -863,31 +898,6 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
|
||||
return outputs
|
||||
|
||||
def _default_height_width(self, height, width, image):
|
||||
# NOTE: It is possible that a list of images have different
|
||||
# dimensions for each image, so just checking the first image
|
||||
# is not _exactly_ correct, but it is simple.
|
||||
while isinstance(image, list):
|
||||
image = image[0]
|
||||
|
||||
if height is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
height = image.height
|
||||
elif isinstance(image, torch.Tensor):
|
||||
height = image.shape[2]
|
||||
|
||||
height = (height // 8) * 8 # round down to nearest multiple of 8
|
||||
|
||||
if width is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
width = image.width
|
||||
elif isinstance(image, torch.Tensor):
|
||||
width = image.shape[3]
|
||||
|
||||
width = (width // 8) * 8 # round down to nearest multiple of 8
|
||||
|
||||
return height, width
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
|
||||
def prepare_mask_latents(
|
||||
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
||||
@@ -950,16 +960,9 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[torch.Tensor, PIL.Image.Image] = None,
|
||||
mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
|
||||
control_image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
image: PipelineImageInput = None,
|
||||
mask_image: PipelineImageInput = None,
|
||||
control_image: PipelineImageInput = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
strength: float = 1.0,
|
||||
@@ -989,14 +992,29 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
|
||||
`List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to
|
||||
be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch
|
||||
tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the
|
||||
expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the
|
||||
expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but
|
||||
if passing latents directly it is not encoded again.
|
||||
mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
|
||||
`List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
|
||||
are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
|
||||
single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
|
||||
color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
|
||||
H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
|
||||
1)`, or `(H, W)`.
|
||||
control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
|
||||
`List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
|
||||
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
|
||||
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
|
||||
also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
|
||||
height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
|
||||
specified in init, images must be passed as a list such that each element of the list can be correctly
|
||||
batched for input to a single controlnet.
|
||||
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. The
|
||||
dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed,
|
||||
`image` is resized according to them. If multiple ControlNets are specified in init, images must be
|
||||
passed as a list such that each element of the list can be correctly batched for input to a single
|
||||
controlnet.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
@@ -1080,9 +1098,6 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
"""
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
# 0. Default height and width to unet
|
||||
height, width = self._default_height_width(height, width, image)
|
||||
|
||||
# align format for control guidance
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
@@ -1137,7 +1152,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -1147,6 +1162,11 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare image
|
||||
if isinstance(controlnet, ControlNetModel):
|
||||
@@ -1184,9 +1204,13 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
assert False
|
||||
|
||||
# 4. Preprocess mask and image - resizes image and mask w.r.t height and width
|
||||
mask, masked_image, init_image = prepare_mask_and_masked_image(
|
||||
image, mask_image, height, width, return_image=True
|
||||
)
|
||||
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
init_image = init_image.to(dtype=torch.float32)
|
||||
|
||||
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||
|
||||
masked_image = init_image * (mask < 0.5)
|
||||
_, _, height, width = init_image.shape
|
||||
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -24,8 +25,8 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
|
||||
|
||||
from diffusers.utils.import_utils import is_invisible_watermark_available
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
@@ -101,7 +102,9 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
class StableDiffusionXLControlNetPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
|
||||
|
||||
@@ -111,6 +114,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
@@ -139,6 +143,13 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
|
||||
Whether the negative prompt embeddings shall always be set to 0. Also see the config of
|
||||
`stabilityai/stable-diffusion-xl-base-1-0`.
|
||||
add_watermarker (`bool`, *optional*):
|
||||
Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to
|
||||
watermark output images. If not defined, it will default to `True` if the package is installed, otherwise
|
||||
no watermarker will be used.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -754,14 +765,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
image: PipelineImageInput = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
@@ -788,6 +792,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
original_size: Tuple[int, int] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Tuple[int, int] = None,
|
||||
negative_original_size: Optional[Tuple[int, int]] = None,
|
||||
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
negative_target_size: Optional[Tuple[int, int]] = None,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -894,6 +901,22 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
||||
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
|
||||
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a target image resolution. It should be as same
|
||||
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
@@ -1057,10 +1080,20 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
)
|
||||
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self._get_add_time_ids(
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
||||
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
@@ -1074,15 +1107,22 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
|
||||
# controlnet(s) inference
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
# Infer ControlNet only for the conditional batch.
|
||||
control_model_input = latents
|
||||
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
||||
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
||||
controlnet_added_cond_kwargs = {
|
||||
"text_embeds": add_text_embeds.chunk(2)[1],
|
||||
"time_ids": add_time_ids.chunk(2)[1],
|
||||
}
|
||||
else:
|
||||
control_model_input = latent_model_input
|
||||
controlnet_prompt_embeds = prompt_embeds
|
||||
controlnet_added_cond_kwargs = added_cond_kwargs
|
||||
|
||||
if isinstance(controlnet_keep[i], list):
|
||||
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
||||
@@ -1092,7 +1132,6 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
controlnet_cond_scale = controlnet_cond_scale[0]
|
||||
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
||||
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
||||
control_model_input,
|
||||
t,
|
||||
@@ -1100,7 +1139,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=cond_scale,
|
||||
guess_mode=guess_mode,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
added_cond_kwargs=controlnet_added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
@@ -1144,13 +1183,19 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
self.controlnet.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
else:
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
@@ -1169,3 +1214,76 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
return (image,)
|
||||
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
# Overrride to properly handle the loading and unloading of the additional text encoder.
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
|
||||
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
||||
# We could have accessed the unet config from `lora_state_dict()` too. We pass
|
||||
# it here explicitly to be able to tell that it's coming from an SDXL
|
||||
# pipeline.
|
||||
state_dict, network_alphas = self.lora_state_dict(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
unet_config=self.unet.config,
|
||||
**kwargs,
|
||||
)
|
||||
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix="text_encoder",
|
||||
lora_scale=self.lora_scale,
|
||||
)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
|
||||
if len(text_encoder_2_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_2_state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder_2,
|
||||
prefix="text_encoder_2",
|
||||
lora_scale=self.lora_scale,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
|
||||
def save_lora_weights(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
):
|
||||
state_dict = {}
|
||||
|
||||
def pack_weights(layers, prefix):
|
||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
return layers_state_dict
|
||||
|
||||
state_dict.update(pack_weights(unet_lora_layers, "unet"))
|
||||
|
||||
if text_encoder_lora_layers and text_encoder_2_lora_layers:
|
||||
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
||||
|
||||
self.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,17 @@
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import (
|
||||
MusicLDMPipeline,
|
||||
)
|
||||
else:
|
||||
from .pipeline_musicldm import MusicLDMPipeline
|
||||
@@ -0,0 +1,607 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import (
|
||||
ClapFeatureExtractor,
|
||||
ClapModel,
|
||||
ClapTextModelWithProjection,
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
SpeechT5HifiGan,
|
||||
)
|
||||
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import is_librosa_available, logging, randn_tensor, replace_example_docstring
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
||||
|
||||
|
||||
if is_librosa_available():
|
||||
import librosa
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers import MusicLDMPipeline
|
||||
>>> import torch
|
||||
>>> import scipy
|
||||
|
||||
>>> repo_id = "cvssp/audioldm-s-full-v2"
|
||||
>>> pipe = MusicLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"
|
||||
>>> audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0]
|
||||
|
||||
>>> # save the audio sample as a .wav file
|
||||
>>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class MusicLDMPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-audio generation using MusicLDM.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`~transformers.ClapModel`]):
|
||||
Frozen text-audio embedding model (`ClapTextModel`), specifically the
|
||||
[laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant.
|
||||
tokenizer ([`PreTrainedTokenizer`]):
|
||||
A [`~transformers.RobertaTokenizer`] to tokenize text.
|
||||
feature_extractor ([`~transformers.ClapFeatureExtractor`]):
|
||||
Feature extractor to compute mel-spectrograms from audio waveforms.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
A `UNet2DConditionModel` to denoise the encoded audio latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
vocoder ([`~transformers.SpeechT5HifiGan`]):
|
||||
Vocoder of class `SpeechT5HifiGan`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: Union[ClapTextModelWithProjection, ClapModel],
|
||||
tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
|
||||
feature_extractor: Optional[ClapFeatureExtractor],
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
vocoder: SpeechT5HifiGan,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vocoder=vocoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_waveforms_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device (`torch.device`):
|
||||
torch device
|
||||
num_waveforms_per_prompt (`int`):
|
||||
number of waveforms that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the audio generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
"""
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLAP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder.get_text_features(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask.to(device),
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.text_model.dtype, device=device)
|
||||
|
||||
(
|
||||
bs_embed,
|
||||
seq_len,
|
||||
) = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
uncond_input_ids = uncond_input.input_ids.to(device)
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
|
||||
negative_prompt_embeds = self.text_encoder.get_text_features(
|
||||
uncond_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.text_model.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform
|
||||
def mel_spectrogram_to_waveform(self, mel_spectrogram):
|
||||
if mel_spectrogram.dim() == 4:
|
||||
mel_spectrogram = mel_spectrogram.squeeze(1)
|
||||
|
||||
waveform = self.vocoder(mel_spectrogram)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
waveform = waveform.cpu().float()
|
||||
return waveform
|
||||
|
||||
# Copied from diffusers.pipelines.audioldm2.pipeline_audioldm2.AudioLDM2Pipeline.score_waveforms
|
||||
def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype):
|
||||
if not is_librosa_available():
|
||||
logger.info(
|
||||
"Automatic scoring of the generated audio waveforms against the input prompt text requires the "
|
||||
"`librosa` package to resample the generated waveforms. Returning the audios in the order they were "
|
||||
"generated. To enable automatic scoring, install `librosa` with: `pip install librosa`."
|
||||
)
|
||||
return audio
|
||||
inputs = self.tokenizer(text, return_tensors="pt", padding=True)
|
||||
resampled_audio = librosa.resample(
|
||||
audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate
|
||||
)
|
||||
inputs["input_features"] = self.feature_extractor(
|
||||
list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate
|
||||
).input_features.type(dtype)
|
||||
inputs = inputs.to(device)
|
||||
|
||||
# compute the audio-text similarity score using the CLAP model
|
||||
logits_per_text = self.text_encoder(**inputs).logits_per_text
|
||||
# sort by the highest matching generations per prompt
|
||||
indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt]
|
||||
audio = torch.index_select(audio, 0, indices.reshape(-1).cpu())
|
||||
return audio
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
audio_length_in_s,
|
||||
vocoder_upsample_factor,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
|
||||
if audio_length_in_s < min_audio_length_in_s:
|
||||
raise ValueError(
|
||||
f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
|
||||
f"is {audio_length_in_s}."
|
||||
)
|
||||
|
||||
if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
|
||||
raise ValueError(
|
||||
f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
|
||||
f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
|
||||
f"{self.vae_scale_factor}."
|
||||
)
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height // self.vae_scale_factor,
|
||||
self.vocoder.config.model_in_dim // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
audio_length_in_s: Optional[float] = None,
|
||||
num_inference_steps: int = 200,
|
||||
guidance_scale: float = 2.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_waveforms_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
|
||||
audio_length_in_s (`int`, *optional*, defaults to 10.24):
|
||||
The length of the generated audio sample in seconds.
|
||||
num_inference_steps (`int`, *optional*, defaults to 200):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 2.0):
|
||||
A higher guidance scale value encourages the model to generate audio that is closely linked to the text
|
||||
`prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
|
||||
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
||||
num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, the text encoding
|
||||
model is a joint text-audio model ([`~transformers.ClapModel`]), and the tokenizer is a
|
||||
`[~transformers.ClapProcessor]`, then automatic scoring will be performed between the generated outputs
|
||||
and the input text. This scoring ranks the generated waveforms based on their cosine similarity to text
|
||||
input in the joint text-audio embedding space.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
||||
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
output_type (`str`, *optional*, defaults to `"np"`):
|
||||
The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
|
||||
`"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
|
||||
model (LDM) output.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.AudioPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is a list with the generated audio.
|
||||
"""
|
||||
# 0. Convert audio input length from seconds to spectrogram height
|
||||
vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate
|
||||
|
||||
if audio_length_in_s is None:
|
||||
audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor
|
||||
|
||||
height = int(audio_length_in_s / vocoder_upsample_factor)
|
||||
|
||||
original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
|
||||
if height % self.vae_scale_factor != 0:
|
||||
height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
|
||||
logger.info(
|
||||
f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
|
||||
f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
|
||||
f"denoising process."
|
||||
)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
audio_length_in_s,
|
||||
vocoder_upsample_factor,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_waveforms_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_waveforms_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=None,
|
||||
class_labels=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 8. Post-processing
|
||||
if not output_type == "latent":
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
mel_spectrogram = self.vae.decode(latents).sample
|
||||
else:
|
||||
return AudioPipelineOutput(audios=latents)
|
||||
|
||||
audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
|
||||
|
||||
audio = audio[:, :original_waveform_length]
|
||||
|
||||
# 9. Automatic scoring
|
||||
if num_waveforms_per_prompt > 1 and prompt is not None:
|
||||
audio = self.score_waveforms(
|
||||
text=prompt,
|
||||
audio=audio,
|
||||
num_waveforms_per_prompt=num_waveforms_per_prompt,
|
||||
device=device,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
|
||||
if output_type == "np":
|
||||
audio = audio.numpy()
|
||||
|
||||
if not return_dict:
|
||||
return (audio,)
|
||||
|
||||
return AudioPipelineOutput(audios=audio)
|
||||
@@ -23,9 +23,9 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class PaintByExampleImageEncoder(CLIPPreTrainedModel):
|
||||
def __init__(self, config, proj_size=768):
|
||||
def __init__(self, config, proj_size=None):
|
||||
super().__init__(config)
|
||||
self.proj_size = proj_size
|
||||
self.proj_size = proj_size or getattr(config, "projection_dim", 768)
|
||||
|
||||
self.model = CLIPVisionModel(config)
|
||||
self.mapper = PaintByExampleMapper(config)
|
||||
|
||||
@@ -694,14 +694,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
pipeline_is_sequentially_offloaded = any(
|
||||
module_is_sequentially_offloaded(module) for _, module in self.components.items()
|
||||
)
|
||||
if pipeline_is_sequentially_offloaded and torch.device(torch_device).type == "cuda":
|
||||
if pipeline_is_sequentially_offloaded and torch_device and torch.device(torch_device).type == "cuda":
|
||||
raise ValueError(
|
||||
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
|
||||
)
|
||||
|
||||
# Display a warning in this case (the operation succeeds but the benefits are lost)
|
||||
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
|
||||
if pipeline_is_offloaded and torch.device(torch_device).type == "cuda":
|
||||
if pipeline_is_offloaded and torch_device and torch.device(torch_device).type == "cuda":
|
||||
logger.warning(
|
||||
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
|
||||
)
|
||||
@@ -924,6 +924,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
use_onnx = kwargs.pop("use_onnx", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
|
||||
# 1. Download the checkpoints and configs
|
||||
@@ -940,6 +941,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
revision=revision,
|
||||
from_flax=from_flax,
|
||||
use_safetensors=use_safetensors,
|
||||
use_onnx=use_onnx,
|
||||
custom_pipeline=custom_pipeline,
|
||||
custom_revision=custom_revision,
|
||||
variant=variant,
|
||||
@@ -1010,8 +1012,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
# define init kwargs
|
||||
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
|
||||
# define init kwargs and make sure that optional component modules are filtered out
|
||||
init_kwargs = {
|
||||
k: init_dict.pop(k)
|
||||
for k in optional_kwargs
|
||||
if k in init_dict and k not in pipeline_class._optional_components
|
||||
}
|
||||
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
|
||||
|
||||
# remove `null` components
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import re
|
||||
from contextlib import nullcontext
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
@@ -1111,7 +1111,7 @@ def convert_controlnet_checkpoint(
|
||||
|
||||
|
||||
def download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path: str,
|
||||
checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
original_config_file: str = None,
|
||||
image_size: Optional[int] = None,
|
||||
prediction_type: str = None,
|
||||
@@ -1144,7 +1144,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
recommended that you override the default values and/or supply an `original_config_file` wherever possible.
|
||||
|
||||
Args:
|
||||
checkpoint_path (`str`): Path to `.ckpt` file.
|
||||
checkpoint_path_or_dict (`str` or `dict`): Path to `.ckpt` file, or the state dict.
|
||||
original_config_file (`str`):
|
||||
Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
|
||||
inferred by looking for a key that only exists in SD2.0 models.
|
||||
@@ -1226,16 +1226,19 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
if from_safetensors:
|
||||
from safetensors.torch import load_file as safe_load
|
||||
if isinstance(checkpoint_path_or_dict, str):
|
||||
if from_safetensors:
|
||||
from safetensors.torch import load_file as safe_load
|
||||
|
||||
checkpoint = safe_load(checkpoint_path, device="cpu")
|
||||
else:
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
checkpoint = safe_load(checkpoint_path_or_dict, device="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
|
||||
else:
|
||||
checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
|
||||
elif isinstance(checkpoint_path_or_dict, dict):
|
||||
checkpoint = checkpoint_path_or_dict
|
||||
|
||||
# Sometimes models don't have the global_step item
|
||||
if "global_step" in checkpoint:
|
||||
@@ -1318,8 +1321,9 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
image_size = 512
|
||||
|
||||
if controlnet is None and "control_stage_config" in original_config.model.params:
|
||||
path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
|
||||
controlnet = convert_controlnet_checkpoint(
|
||||
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
|
||||
checkpoint, original_config, path, image_size, upcast_attention, extract_ema
|
||||
)
|
||||
|
||||
num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
|
||||
@@ -1378,8 +1382,9 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
# Convert the UNet2DConditionModel model.
|
||||
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
unet_config["upcast_attention"] = upcast_attention
|
||||
path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
|
||||
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
|
||||
checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
|
||||
checkpoint, unet_config, path=path, extract_ema=extract_ema
|
||||
)
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
@@ -1387,8 +1392,9 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
unet = UNet2DConditionModel(**unet_config)
|
||||
|
||||
if is_accelerate_available():
|
||||
for param_name, param in converted_unet_checkpoint.items():
|
||||
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
|
||||
if model_type not in ["SDXL", "SDXL-Refiner"]: # SBM Delay this.
|
||||
for param_name, param in converted_unet_checkpoint.items():
|
||||
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
|
||||
else:
|
||||
unet.load_state_dict(converted_unet_checkpoint)
|
||||
|
||||
@@ -1588,16 +1594,34 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
|
||||
)
|
||||
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
force_zeros_for_empty_prompt=True,
|
||||
)
|
||||
if is_accelerate_available(): # SBM Now move model to cpu.
|
||||
if model_type in ["SDXL", "SDXL-Refiner"]:
|
||||
for param_name, param in converted_unet_checkpoint.items():
|
||||
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
|
||||
|
||||
if controlnet:
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
controlnet=controlnet,
|
||||
scheduler=scheduler,
|
||||
force_zeros_for_empty_prompt=True,
|
||||
)
|
||||
else:
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
force_zeros_for_empty_prompt=True,
|
||||
)
|
||||
else:
|
||||
tokenizer = None
|
||||
text_encoder = None
|
||||
@@ -1611,6 +1635,11 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs
|
||||
)
|
||||
|
||||
if is_accelerate_available(): # SBM Now move model to cpu.
|
||||
if model_type in ["SDXL", "SDXL-Refiner"]:
|
||||
for param_name, param in converted_unet_checkpoint.items():
|
||||
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
|
||||
|
||||
pipe = StableDiffusionXLImg2ImgPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
|
||||
@@ -25,7 +25,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
from diffusers.utils import is_accelerate_available, is_accelerate_version
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import DDIMScheduler
|
||||
@@ -270,12 +270,43 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -414,12 +445,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
|
||||
def check_inputs(
|
||||
@@ -578,14 +604,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
source_prompt: Union[str, List[str]],
|
||||
image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
image: PipelineImageInput = None,
|
||||
strength: float = 0.8,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 7.5,
|
||||
@@ -745,7 +764,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -753,9 +772,17 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
prompt_embeds=prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
source_prompt_embeds = self._encode_prompt(
|
||||
source_prompt_embeds_tuple = self.encode_prompt(
|
||||
source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None
|
||||
)
|
||||
if prompt_embeds_tuple[1] is not None:
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
else:
|
||||
prompt_embeds = prompt_embeds_tuple[0]
|
||||
if source_prompt_embeds_tuple[1] is not None:
|
||||
source_prompt_embeds = torch.cat([source_prompt_embeds_tuple[1], source_prompt_embeds_tuple[0]])
|
||||
else:
|
||||
source_prompt_embeds = source_prompt_embeds_tuple[0]
|
||||
|
||||
# 4. Preprocess image
|
||||
image = self.image_processor.preprocess(image)
|
||||
|
||||
+369
-177
@@ -1,26 +1,34 @@
|
||||
from logging import getLogger
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTokenizer
|
||||
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
|
||||
from ...utils import deprecate, logging
|
||||
from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
||||
from ..pipeline_utils import ImagePipelineOutput
|
||||
from . import StableDiffusionUpscalePipeline
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
NUM_LATENT_CHANNELS = 4
|
||||
NUM_UNET_INPUT_CHANNELS = 7
|
||||
|
||||
ORT_TO_PT_TYPE = {
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
@@ -45,7 +53,17 @@ def preprocess(image):
|
||||
return image
|
||||
|
||||
|
||||
class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):
|
||||
vae: OnnxRuntimeModel
|
||||
text_encoder: OnnxRuntimeModel
|
||||
tokenizer: CLIPTokenizer
|
||||
unet: OnnxRuntimeModel
|
||||
low_res_scheduler: DDPMScheduler
|
||||
scheduler: KarrasDiffusionSchedulers
|
||||
safety_checker: OnnxRuntimeModel
|
||||
feature_extractor: CLIPImageProcessor
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_is_onnx = True
|
||||
|
||||
def __init__(
|
||||
@@ -55,39 +73,296 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
tokenizer: Any,
|
||||
unet: OnnxRuntimeModel,
|
||||
low_res_scheduler: DDPMScheduler,
|
||||
scheduler: Any,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: Optional[OnnxRuntimeModel] = None,
|
||||
feature_extractor: Optional[CLIPImageProcessor] = None,
|
||||
max_noise_level: int = 350,
|
||||
num_latent_channels=4,
|
||||
num_unet_input_channels=7,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
||||
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
||||
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
||||
" file"
|
||||
)
|
||||
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
low_res_scheduler=low_res_scheduler,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
watermarker=None,
|
||||
max_noise_level=max_noise_level,
|
||||
low_res_scheduler=low_res_scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(
|
||||
max_noise_level=max_noise_level,
|
||||
num_latent_channels=num_latent_channels,
|
||||
num_unet_input_channels=num_unet_input_channels,
|
||||
)
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
image,
|
||||
noise_level,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if (
|
||||
not isinstance(image, torch.Tensor)
|
||||
and not isinstance(image, PIL.Image.Image)
|
||||
and not isinstance(image, np.ndarray)
|
||||
and not isinstance(image, list)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}"
|
||||
)
|
||||
|
||||
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
|
||||
if isinstance(image, list) or isinstance(image, np.ndarray):
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if isinstance(image, list):
|
||||
image_batch_size = len(image)
|
||||
else:
|
||||
image_batch_size = image.shape[0]
|
||||
if batch_size != image_batch_size:
|
||||
raise ValueError(
|
||||
f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
|
||||
" Please make sure that passed `prompt` matches the batch size of `image`."
|
||||
)
|
||||
|
||||
# check noise level
|
||||
if noise_level > self.config.max_noise_level:
|
||||
raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
if latents is None:
|
||||
latents = generator.randn(*shape).astype(dtype)
|
||||
elif latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
|
||||
return latents
|
||||
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.08333 * latents
|
||||
image = self.vae(latent_sample=latents)[0]
|
||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
||||
image = image.transpose((0, 2, 3, 1))
|
||||
return image
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: Optional[int],
|
||||
do_classifier_free_guidance: bool,
|
||||
negative_prompt: Optional[str],
|
||||
prompt_embeds: Optional[np.ndarray] = None,
|
||||
negative_prompt_embeds: Optional[np.ndarray] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
prompt to be encoded
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
prompt_embeds (`np.ndarray`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`np.ndarray`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
"""
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
|
||||
|
||||
if not np.array_equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
|
||||
|
||||
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt] * batch_size
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
|
||||
image: Union[np.ndarray, PIL.Image.Image, List[PIL.Image.Image]],
|
||||
num_inference_steps: int = 75,
|
||||
guidance_scale: float = 9.0,
|
||||
noise_level: int = 20,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
generator: Optional[Union[np.random.RandomState, List[np.random.RandomState]]] = None,
|
||||
latents: Optional[np.ndarray] = None,
|
||||
prompt_embeds: Optional[np.ndarray] = None,
|
||||
negative_prompt_embeds: Optional[np.ndarray] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
):
|
||||
r"""
|
||||
@@ -108,7 +383,8 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
noise_level TODO
|
||||
noise_level (`float`, defaults to 0.2):
|
||||
Deteremines the amount of noise to add to the initial image before performing upscaling.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
@@ -152,7 +428,15 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
"""
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(prompt, image, noise_level, callback_steps)
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
image,
|
||||
noise_level,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -162,16 +446,16 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
if generator is None:
|
||||
generator = np.random
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_embeddings = self._encode_prompt(
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
@@ -179,51 +463,55 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
|
||||
latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)]
|
||||
latents_dtype = prompt_embeds.dtype
|
||||
image = preprocess(image).cpu().numpy()
|
||||
height, width = image.shape[2:]
|
||||
|
||||
# 4. Preprocess image
|
||||
image = preprocess(image)
|
||||
image = image.cpu()
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
self.num_latent_channels,
|
||||
height,
|
||||
width,
|
||||
latents_dtype,
|
||||
generator,
|
||||
)
|
||||
image = image.astype(latents_dtype)
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# Scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * np.float64(self.scheduler.init_noise_sigma)
|
||||
|
||||
# 5. Add noise to image
|
||||
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
|
||||
noise = torch.randn(image.shape, generator=generator, device=device, dtype=latents_dtype)
|
||||
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
||||
noise_level = np.array([noise_level]).astype(np.int64)
|
||||
noise = generator.randn(*image.shape).astype(latents_dtype)
|
||||
|
||||
image = self.low_res_scheduler.add_noise(
|
||||
torch.from_numpy(image), torch.from_numpy(noise), torch.from_numpy(noise_level)
|
||||
)
|
||||
image = image.numpy()
|
||||
|
||||
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
||||
image = np.concatenate([image] * batch_multiplier * num_images_per_prompt)
|
||||
noise_level = np.concatenate([noise_level] * image.shape[0])
|
||||
|
||||
# 6. Prepare latent variables
|
||||
height, width = image.shape[2:]
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
NUM_LATENT_CHANNELS,
|
||||
height,
|
||||
width,
|
||||
latents_dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 7. Check that sizes of image and latents match
|
||||
num_channels_image = image.shape[1]
|
||||
if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS:
|
||||
if self.num_latent_channels + num_channels_image != self.num_unet_input_channels:
|
||||
raise ValueError(
|
||||
"Incorrect configuration settings! The config of `pipeline.unet` expects"
|
||||
f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +"
|
||||
f" {self.num_unet_input_channels} but received `num_channels_latents`: {self.num_latent_channels} +"
|
||||
f" `num_channels_image`: {num_channels_image} "
|
||||
f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of"
|
||||
f" = {self.num_latent_channels + num_channels_image}. Please verify the config of"
|
||||
" `pipeline.unet` or your `image` input."
|
||||
)
|
||||
|
||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
timestep_dtype = next(
|
||||
(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
|
||||
@@ -248,8 +536,8 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
class_labels=noise_level.astype(np.int64),
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
class_labels=noise_level,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
@@ -259,8 +547,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs
|
||||
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
|
||||
).prev_sample
|
||||
latents = latents.numpy()
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
@@ -269,125 +558,28 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
callback(i, t, latents)
|
||||
|
||||
# 10. Post-processing
|
||||
image = self.decode_latents(latents.float())
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(
|
||||
self.numpy_to_pil(image), return_tensors="np"
|
||||
).pixel_values.astype(image.dtype)
|
||||
|
||||
images, has_nsfw_concept = [], []
|
||||
for i in range(image.shape[0]):
|
||||
image_i, has_nsfw_concept_i = self.safety_checker(
|
||||
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
|
||||
)
|
||||
images.append(image_i)
|
||||
has_nsfw_concept.append(has_nsfw_concept_i[0])
|
||||
image = np.concatenate(images)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
|
||||
# 11. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.08333 * latents
|
||||
image = self.vae(latent_sample=latents)[0]
|
||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
||||
image = image.transpose((0, 2, 3, 1))
|
||||
return image
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device,
|
||||
num_images_per_prompt: Optional[int],
|
||||
do_classifier_free_guidance: bool,
|
||||
negative_prompt: Optional[str],
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
# no positional arguments to text_encoder
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids.int().to(device),
|
||||
# attention_mask=attention_mask,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
prompt_embeds = prompt_embeds.reshape(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# if hasattr(uncond_input, "attention_mask"):
|
||||
# attention_mask = uncond_input.attention_mask.to(device)
|
||||
# else:
|
||||
# attention_mask = None
|
||||
|
||||
uncond_embeddings = self.text_encoder(
|
||||
input_ids=uncond_input.input_ids.int().to(device),
|
||||
# attention_mask=attention_mask,
|
||||
)
|
||||
uncond_embeddings = uncond_embeddings[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
|
||||
uncond_embeddings = uncond_embeddings.reshape(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = np.concatenate([uncond_embeddings, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
|
||||
@@ -260,12 +260,42 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -404,12 +434,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
@@ -634,7 +659,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -644,6 +669,11 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
+40
-9
@@ -27,7 +27,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models.attention_processor import Attention
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import logging, randn_tensor, replace_example_docstring
|
||||
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -259,12 +259,43 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -403,12 +434,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -810,7 +836,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -819,6 +845,11 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
@@ -24,7 +24,7 @@ from packaging import version
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
@@ -144,12 +144,43 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -288,12 +319,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -499,14 +525,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
image: PipelineImageInput = None,
|
||||
depth_map: Optional[torch.FloatTensor] = None,
|
||||
strength: float = 0.8,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
@@ -638,7 +657,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -648,6 +667,11 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare depth mask
|
||||
depth_mask = self.prepare_depth_map(
|
||||
|
||||
@@ -445,12 +445,43 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -589,12 +620,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -970,7 +996,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
|
||||
# 3. Encode input prompts
|
||||
(cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None)
|
||||
target_prompt_embeds = self._encode_prompt(
|
||||
target_negative_prompt_embeds, target_prompt_embeds = self.encode_prompt(
|
||||
target_prompt,
|
||||
device,
|
||||
num_maps_per_mask,
|
||||
@@ -979,8 +1005,13 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
prompt_embeds=target_prompt_embeds,
|
||||
negative_prompt_embeds=target_negative_prompt_embeds,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
target_prompt_embeds = torch.cat([target_negative_prompt_embeds, target_prompt_embeds])
|
||||
|
||||
source_prompt_embeds = self._encode_prompt(
|
||||
source_negative_prompt_embeds, source_prompt_embeds = self.encode_prompt(
|
||||
source_prompt,
|
||||
device,
|
||||
num_maps_per_mask,
|
||||
@@ -989,6 +1020,8 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
prompt_embeds=source_prompt_embeds,
|
||||
negative_prompt_embeds=source_negative_prompt_embeds,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
source_prompt_embeds = torch.cat([source_negative_prompt_embeds, source_prompt_embeds])
|
||||
|
||||
# 4. Preprocess image
|
||||
image = self.image_processor.preprocess(image).repeat_interleave(num_maps_per_mask, dim=0)
|
||||
@@ -1178,7 +1211,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
)
|
||||
|
||||
# 5. Encode input prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -1187,6 +1220,11 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 6. Prepare timesteps
|
||||
self.inverse_scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
@@ -1410,7 +1448,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -1420,6 +1458,11 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Preprocess mask
|
||||
mask_image = preprocess_mask(mask_image, batch_size)
|
||||
|
||||
@@ -26,6 +26,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models.attention import GatedSelfAttentionDense
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
@@ -234,12 +235,43 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -378,12 +410,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -654,7 +681,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -663,6 +690,11 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
@@ -23,7 +23,7 @@ from packaging import version
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
@@ -264,12 +264,43 @@ class StableDiffusionImg2ImgPipeline(
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -408,12 +439,7 @@ class StableDiffusionImg2ImgPipeline(
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -573,14 +599,7 @@ class StableDiffusionImg2ImgPipeline(
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
image: PipelineImageInput = None,
|
||||
strength: float = 0.8,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 7.5,
|
||||
@@ -603,7 +622,10 @@ class StableDiffusionImg2ImgPipeline(
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image` or tensor representing an image batch to be used as the starting point. Can also accept image
|
||||
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
||||
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
||||
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
||||
latents as `image`, but if passing latents directly it is not encoded again.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
||||
@@ -678,7 +700,7 @@ class StableDiffusionImg2ImgPipeline(
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -688,6 +710,11 @@ class StableDiffusionImg2ImgPipeline(
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Preprocess image
|
||||
image = self.image_processor.preprocess(image)
|
||||
|
||||
@@ -22,7 +22,7 @@ from packaging import version
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
@@ -63,7 +63,12 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
|
||||
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
|
||||
dimensions: ``batch x channels x height x width``.
|
||||
"""
|
||||
|
||||
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
|
||||
deprecate(
|
||||
"prepare_mask_and_masked_image",
|
||||
"0.30.0",
|
||||
deprecation_message,
|
||||
)
|
||||
if image is None:
|
||||
raise ValueError("`image` input cannot be undefined.")
|
||||
|
||||
@@ -280,6 +285,9 @@ class StableDiffusionInpaintPipeline(
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
|
||||
@@ -322,12 +330,43 @@ class StableDiffusionInpaintPipeline(
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -466,12 +505,7 @@ class StableDiffusionInpaintPipeline(
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -679,8 +713,8 @@ class StableDiffusionInpaintPipeline(
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
||||
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
||||
image: PipelineImageInput = None,
|
||||
mask_image: PipelineImageInput = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
strength: float = 1.0,
|
||||
@@ -705,14 +739,20 @@ class StableDiffusionInpaintPipeline(
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
||||
image (`PIL.Image.Image`):
|
||||
`Image` or tensor representing an image batch to be inpainted (which parts of the image to be masked
|
||||
out with `mask_image` and repainted according to `prompt`).
|
||||
mask_image (`PIL.Image.Image`):
|
||||
`Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted
|
||||
while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel
|
||||
(luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the
|
||||
expected shape would be `(B, H, W, 1)`.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to
|
||||
be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch
|
||||
tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the
|
||||
expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the
|
||||
expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but
|
||||
if passing latents directly it is not encoded again.
|
||||
mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
|
||||
are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
|
||||
single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
|
||||
color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
|
||||
H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
|
||||
1)`, or `(H, W)`.
|
||||
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
@@ -837,7 +877,7 @@ class StableDiffusionInpaintPipeline(
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -847,6 +887,11 @@ class StableDiffusionInpaintPipeline(
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
@@ -865,9 +910,14 @@ class StableDiffusionInpaintPipeline(
|
||||
is_strength_max = strength == 1.0
|
||||
|
||||
# 5. Preprocess mask and image
|
||||
mask, masked_image, init_image = prepare_mask_and_masked_image(
|
||||
image, mask_image, height, width, return_image=True
|
||||
)
|
||||
|
||||
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
init_image = init_image.to(dtype=torch.float32)
|
||||
|
||||
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||
|
||||
masked_image = init_image * (mask < 0.5)
|
||||
|
||||
mask_condition = mask.clone()
|
||||
|
||||
# 6. Prepare latent variables
|
||||
|
||||
+39
-8
@@ -260,12 +260,43 @@ class StableDiffusionInpaintPipelineLegacy(
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -404,12 +435,7 @@ class StableDiffusionInpaintPipelineLegacy(
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -643,7 +669,7 @@ class StableDiffusionInpaintPipelineLegacy(
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -653,6 +679,11 @@ class StableDiffusionInpaintPipelineLegacy(
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Preprocess image and mask
|
||||
if not isinstance(image, torch.FloatTensor):
|
||||
|
||||
+2
-9
@@ -21,7 +21,7 @@ import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
@@ -147,14 +147,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
image: PipelineImageInput = None,
|
||||
num_inference_steps: int = 100,
|
||||
guidance_scale: float = 7.5,
|
||||
image_guidance_scale: float = 1.5,
|
||||
|
||||
@@ -24,7 +24,7 @@ from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...schedulers import LMSDiscreteScheduler
|
||||
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor
|
||||
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
|
||||
@@ -167,12 +167,43 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -311,12 +342,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -523,7 +549,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
raise ValueError("has to use guidance_scale")
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -532,6 +558,11 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device)
|
||||
|
||||
+2
-9
@@ -21,7 +21,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import EulerDiscreteScheduler
|
||||
from ...utils import logging, randn_tensor
|
||||
@@ -257,14 +257,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
] = None,
|
||||
image: PipelineImageInput = None,
|
||||
num_inference_steps: int = 75,
|
||||
guidance_scale: float = 9.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
|
||||
@@ -27,6 +27,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
BaseOutput,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
@@ -229,12 +230,43 @@ class StableDiffusionLDM3DPipeline(
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -373,12 +405,7 @@ class StableDiffusionLDM3DPipeline(
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
@@ -585,7 +612,7 @@ class StableDiffusionLDM3DPipeline(
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -594,6 +621,11 @@ class StableDiffusionLDM3DPipeline(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
+40
-9
@@ -24,7 +24,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import PNDMScheduler
|
||||
from ...schedulers.scheduling_utils import SchedulerMixin
|
||||
from ...utils import logging, randn_tensor
|
||||
from ...utils import deprecate, logging, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -171,12 +171,43 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -315,12 +346,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -679,7 +705,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -689,6 +715,11 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
@@ -23,7 +23,7 @@ from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import DDIMScheduler
|
||||
from ...utils import logging, randn_tensor, replace_example_docstring
|
||||
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -148,12 +148,43 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
@@ -292,12 +323,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -566,7 +592,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -576,6 +602,11 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user