Compare commits
51 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 813a1b2ee0 | |||
| a43b8574a9 | |||
| a2f0db52e3 | |||
| 92f6693b37 | |||
| 932897afa8 | |||
| c2940434d0 | |||
| 60ab8fad16 | |||
| d17240457f | |||
| 7512fc4df5 | |||
| 0c2f1ccc97 | |||
| 47f2d2c7be | |||
| af85591593 | |||
| 29f15673ed | |||
| 1037287e2b | |||
| 6ea95b7a90 | |||
| 0e0db625d0 | |||
| 1f948109b8 | |||
| 37cb819df5 | |||
| f64d52dbca | |||
| 4d897aaff5 | |||
| b1105269b7 | |||
| 5d28d2217f | |||
| 73bf620dec | |||
| c806f2fad6 | |||
| 18b7264bd0 | |||
| d82157b3ce | |||
| 93579650f8 | |||
| 16a056a7b5 | |||
| 6c6a246461 | |||
| 6bbee1048b | |||
| 2c60f7d14e | |||
| b6e0b016ce | |||
| 88735249da | |||
| 4191ddee11 | |||
| 2ab170499e | |||
| 914c513ee0 | |||
| d73e6ad050 | |||
| d0cf681a1f | |||
| dfec61f4b3 | |||
| 0ec7a02b6a | |||
| 626284f8d1 | |||
| 9800cc5ece | |||
| 541bb6ee63 | |||
| b76274cb53 | |||
| dc3e0ca59b | |||
| 6c314ad0ce | |||
| 946bb53c56 | |||
| ea311e6989 | |||
| 4c5718a09c | |||
| 2340ed629e | |||
| cfdfcf2018 |
@@ -41,7 +41,7 @@ Core library:
|
||||
- Schedulers: @williamberman and @patrickvonplaten
|
||||
- Pipelines: @patrickvonplaten and @sayakpaul
|
||||
- Training examples: @sayakpaul and @patrickvonplaten
|
||||
- Docs: @stevenliu and @yiyixu
|
||||
- Docs: @stevhliu and @yiyixuxu
|
||||
- JAX and MPS: @pcuenca
|
||||
- Audio: @sanchit-gandhi
|
||||
- General functionalities: @patrickvonplaten and @sayakpaul
|
||||
|
||||
@@ -67,6 +67,7 @@ jobs:
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
@@ -63,6 +63,7 @@ jobs:
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
${CONDA_RUN} python -m pip install --upgrade pip
|
||||
${CONDA_RUN} python -m pip install -e .[quality,test]
|
||||
${CONDA_RUN} python -m pip install torch torchvision torchaudio
|
||||
${CONDA_RUN} python -m pip install accelerate --upgrade
|
||||
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate.git
|
||||
${CONDA_RUN} python -m pip install transformers --upgrade
|
||||
|
||||
- name: Environment
|
||||
|
||||
@@ -10,6 +10,9 @@
|
||||
<a href="https://github.com/huggingface/diffusers/releases">
|
||||
<img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/diffusers.svg">
|
||||
</a>
|
||||
<a href="https://pepy.tech/project/diffusers">
|
||||
<img alt="GitHub release" src="https://static.pepy.tech/badge/diffusers/month">
|
||||
</a>
|
||||
<a href="CODE_OF_CONDUCT.md">
|
||||
<img alt="Contributor Covenant" src="https://img.shields.io/badge/Contributor%20Covenant-2.0-4baaaa.svg">
|
||||
</a>
|
||||
|
||||
@@ -102,6 +102,8 @@
|
||||
title: InstructPix2Pix Training
|
||||
- local: training/custom_diffusion
|
||||
title: Custom Diffusion
|
||||
- local: training/t2i_adapters
|
||||
title: T2I-Adapters
|
||||
title: Training
|
||||
- sections:
|
||||
- local: using-diffusers/other-modalities
|
||||
@@ -310,6 +312,8 @@
|
||||
title: Versatile Diffusion
|
||||
- local: api/pipelines/vq_diffusion
|
||||
title: VQ Diffusion
|
||||
- local: api/pipelines/wuerstchen
|
||||
title: Wuerstchen
|
||||
title: Pipelines
|
||||
- sections:
|
||||
- local: api/schedulers/overview
|
||||
|
||||
@@ -35,4 +35,12 @@ Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to le
|
||||
- save_lora_weights
|
||||
|
||||
## StableDiffusionPipelineOutput
|
||||
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
|
||||
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
|
||||
|
||||
## StableDiffusionXLInstructPix2PixPipeline
|
||||
[[autodoc]] StableDiffusionXLInstructPix2PixPipeline
|
||||
- __call__
|
||||
- all
|
||||
|
||||
## StableDiffusionXLPipelineOutput
|
||||
[[autodoc]] pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput
|
||||
@@ -20,7 +20,7 @@ The abstract from the paper is:
|
||||
|
||||
## Tips
|
||||
|
||||
- SDXL works especially well with images between 768 and 1024.
|
||||
- Most SDXL checkpoints work best with an image size of 1024x1024. Image sizes of 768x768 and 512x512 are also supported, but the results aren't as good. Anything below 512x512 is not recommended and likely won't for for default checkpoints like [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).
|
||||
- SDXL can pass a different prompt for each of the text encoders it was trained on. We can even pass different parts of the same prompt to the text encoders.
|
||||
- SDXL output images can be improved by making use of a refiner model in an image-to-image setting.
|
||||
- SDXL offers `negative_original_size`, `negative_crops_coords_top_left`, and `negative_target_size` to negatively condition the model on image resolution and cropping parameters.
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
# Würstchen
|
||||
|
||||
<img src="https://github.com/dome272/Wuerstchen/assets/61938694/0617c863-165a-43ee-9303-2a17299a0cf9">
|
||||
|
||||
[Würstchen: Efficient Pretraining of Text-to-Image Models](https://huggingface.co/papers/2306.00637) is by Pablo Pernias, Dominic Rampas, and Marc Aubreville.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We introduce Würstchen, a novel technique for text-to-image synthesis that unites competitive performance with unprecedented cost-effectiveness and ease of training on constrained hardware. Building on recent advancements in machine learning, our approach, which utilizes latent diffusion strategies at strong latent image compression rates, significantly reduces the computational burden, typically associated with state-of-the-art models, while preserving, if not enhancing, the quality of generated images. Wuerstchen achieves notable speed improvements at inference time, thereby rendering real-time applications more viable. One of the key advantages of our method lies in its modest training requirements of only 9,200 GPU hours, slashing the usual costs significantly without compromising the end performance. In a comparison against the state-of-the-art, we found the approach to yield strong competitiveness. This paper opens the door to a new line of research that prioritizes both performance and computational accessibility, hence democratizing the use of sophisticated AI technologies. Through Wuerstchen, we demonstrate a compelling stride forward in the realm of text-to-image synthesis, offering an innovative path to explore in future research.*
|
||||
|
||||
## Würstchen v2 comes to Diffusers
|
||||
|
||||
After the initial paper release, we have improved numerous things in the architecture, training and sampling, making Würstchen competetive to current state-of-the-art models in many ways. We are excited to release this new version together with Diffusers. Here is a list of the improvements.
|
||||
|
||||
- Higher resolution (1024x1024 up to 2048x2048)
|
||||
- Faster inference
|
||||
- Multi Aspect Resolution Sampling
|
||||
- Better quality
|
||||
|
||||
|
||||
We are releasing 3 checkpoints for the text-conditional image generation model (Stage C). Those are:
|
||||
|
||||
- v2-base
|
||||
- v2-aesthetic
|
||||
- v2-interpolated (50% interpolation between v2-base and v2-aesthetic)
|
||||
|
||||
We recommend to use v2-interpolated, as it has a nice touch of both photorealism and aesthetic. Use v2-base for finetunings as it does not have a style bias and use v2-aesthetic for very artistic generations.
|
||||
A comparison can be seen here:
|
||||
|
||||
<img src="https://github.com/dome272/Wuerstchen/assets/61938694/2914830f-cbd3-461c-be64-d50734f4b49d" width=500>
|
||||
|
||||
## Text-to-Image Generation
|
||||
|
||||
For the sake of usability Würstchen can be used with a single pipeline. This pipeline is called `WuerstchenCombinedPipeline` and can be used as follows:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
|
||||
|
||||
pipe = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
caption = "Anthropomorphic cat dressed as a fire fighter"
|
||||
images = pipe(
|
||||
caption,
|
||||
width=1024,
|
||||
height=1536,
|
||||
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
|
||||
prior_guidance_scale=4.0,
|
||||
num_images_per_prompt=2,
|
||||
).images
|
||||
```
|
||||
|
||||
For explanation purposes, we can also initialize the two main pipelines of Würstchen individually. Würstchen consists of 3 stages: Stage C, Stage B, Stage A. They all have different jobs and work only together. When generating text-conditional images, Stage C will first generate the latents in a very compressed latent space. This is what happens in the `prior_pipeline`. Afterwards, the generated latents will be passed to Stage B, which decompresses the latents into a bigger latent space of a VQGAN. These latents can then be decoded by Stage A, which is a VQGAN, into the pixel-space. Stage B & Stage A are both encapsulated in the `decoder_pipeline`. For more details, take a look at the [paper](https://huggingface.co/papers/2306.00637).
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
|
||||
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
|
||||
|
||||
device = "cuda"
|
||||
dtype = torch.float16
|
||||
num_images_per_prompt = 2
|
||||
|
||||
prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
|
||||
"warp-ai/wuerstchen-prior", torch_dtype=dtype
|
||||
).to(device)
|
||||
decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained(
|
||||
"warp-ai/wuerstchen", torch_dtype=dtype
|
||||
).to(device)
|
||||
|
||||
caption = "Anthropomorphic cat dressed as a fire fighter"
|
||||
negative_prompt = ""
|
||||
|
||||
prior_output = prior_pipeline(
|
||||
prompt=caption,
|
||||
height=1024,
|
||||
width=1536,
|
||||
timesteps=DEFAULT_STAGE_C_TIMESTEPS,
|
||||
negative_prompt=negative_prompt,
|
||||
guidance_scale=4.0,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
)
|
||||
decoder_output = decoder_pipeline(
|
||||
image_embeddings=prior_output.image_embeddings,
|
||||
prompt=caption,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
guidance_scale=0.0,
|
||||
output_type="pil",
|
||||
).images
|
||||
```
|
||||
|
||||
## Speed-Up Inference
|
||||
You can make use of `torch.compile` function and gain a speed-up of about 2-3x:
|
||||
|
||||
```python
|
||||
pipeline.prior = torch.compile(pipeline.prior, mode="reduce-overhead", fullgraph=True)
|
||||
pipeline.decoder = torch.compile(pipeline.decoder, mode="reduce-overhead", fullgraph=True)
|
||||
```
|
||||
|
||||
## Limitations
|
||||
|
||||
- Due to the high compression employed by Würstchen, generations can lack a good amount
|
||||
of detail. To our human eye, this is especially noticeable in faces, hands etc.
|
||||
- **Images can only be generated in 128-pixel steps**, e.g. the next higher resolution
|
||||
after 1024x1024 is 1152x1152
|
||||
- The model lacks the ability to render correct text in images
|
||||
- The model often does not achieve photorealism
|
||||
- Difficult compositional prompts are hard for the model
|
||||
|
||||
The original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen).
|
||||
|
||||
## WuerschenPipeline
|
||||
|
||||
[[autodoc]] WuerstchenCombinedPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## WuerstchenPriorPipeline
|
||||
|
||||
[[autodoc]] WuerstchenDecoderPipeline
|
||||
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## WuerstchenPriorPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.wuerstchen.pipeline_wuerstchen_prior.WuerstchenPriorPipelineOutput
|
||||
|
||||
## WuerstchenDecoderPipeline
|
||||
|
||||
[[autodoc]] WuerstchenDecoderPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -2,30 +2,26 @@
|
||||
|
||||
Utility and helper functions for working with 🤗 Diffusers.
|
||||
|
||||
## randn_tensor
|
||||
|
||||
[[autodoc]] diffusers.utils.randn_tensor
|
||||
|
||||
## numpy_to_pil
|
||||
|
||||
[[autodoc]] utils.pil_utils.numpy_to_pil
|
||||
[[autodoc]] utils.numpy_to_pil
|
||||
|
||||
## pt_to_pil
|
||||
|
||||
[[autodoc]] utils.pil_utils.pt_to_pil
|
||||
[[autodoc]] utils.pt_to_pil
|
||||
|
||||
## load_image
|
||||
|
||||
[[autodoc]] utils.testing_utils.load_image
|
||||
[[autodoc]] utils.load_image
|
||||
|
||||
## export_to_gif
|
||||
|
||||
[[autodoc]] utils.testing_utils.export_to_gif
|
||||
[[autodoc]] utils.export_to_gif
|
||||
|
||||
## export_to_video
|
||||
|
||||
[[autodoc]] utils.testing_utils.export_to_video
|
||||
[[autodoc]] utils.export_to_video
|
||||
|
||||
## make_image_grid
|
||||
|
||||
[[autodoc]] utils.pil_utils.make_image_grid
|
||||
[[autodoc]] utils.pil_utils.make_image_grid
|
||||
|
||||
@@ -301,6 +301,42 @@ You can call [`~diffusers.loaders.LoraLoaderMixin.fuse_lora`] on a pipeline to m
|
||||
|
||||
To undo `fuse_lora`, call [`~diffusers.loaders.LoraLoaderMixin.unfuse_lora`] on a pipeline.
|
||||
|
||||
## Working with different LoRA scales when using LoRA fusion
|
||||
|
||||
If you need to use `scale` when working with `fuse_lora()` to control the influence of the LoRA parameters on the outputs, you should specify `lora_scale` within `fuse_lora()`. Passing the `scale` parameter to `cross_attention_kwargs` when you call the pipeline won't work.
|
||||
|
||||
To use a different `lora_scale` with `fuse_lora()`, you should first call `unfuse_lora()` on the corresponding pipeline and call `fuse_lora()` again with the expected `lora_scale`.
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
|
||||
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
|
||||
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
|
||||
|
||||
# This uses a default `lora_scale` of 1.0.
|
||||
pipe.fuse_lora()
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
images_fusion = pipe(
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
|
||||
).images
|
||||
|
||||
# To work with a different `lora_scale`, first reverse the effects of `fuse_lora()`.
|
||||
pipe.unfuse_lora()
|
||||
|
||||
# Then proceed as follows.
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
|
||||
pipe.fuse_lora(lora_scale=0.5)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
images_fusion = pipe(
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
|
||||
).images
|
||||
```
|
||||
|
||||
## Supporting different LoRA checkpoints from Diffusers
|
||||
|
||||
🤗 Diffusers supports loading checkpoints from popular LoRA trainers such as [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). In this section, we outline the current API's details and limitations.
|
||||
|
||||
@@ -34,13 +34,16 @@ If you feel like another important example should exist, we are more than happy
|
||||
Training examples show how to pretrain or fine-tune diffusion models for a variety of tasks. Currently we support:
|
||||
|
||||
- [Unconditional Training](./unconditional_training)
|
||||
- [Text-to-Image Training](./text2image)
|
||||
- [Text-to-Image Training](./text2image)<sup>*</sup>
|
||||
- [Text Inversion](./text_inversion)
|
||||
- [Dreambooth](./dreambooth)
|
||||
- [LoRA Support](./lora)
|
||||
- [ControlNet](./controlnet)
|
||||
- [InstructPix2Pix](./instructpix2pix)
|
||||
- [Dreambooth](./dreambooth)<sup>*</sup>
|
||||
- [LoRA Support](./lora)<sup>*</sup>
|
||||
- [ControlNet](./controlnet)<sup>*</sup>
|
||||
- [InstructPix2Pix](./instructpix2pix)<sup>*</sup>
|
||||
- [Custom Diffusion](./custom_diffusion)
|
||||
- [T2I-Adapters](./t2i_adapters)<sup>*</sup>
|
||||
|
||||
<sup>*</sup>: Supports [Stable Diffusion XL](../api/pipelines/stable_diffusion/stable_diffusion_xl).
|
||||
|
||||
If possible, please [install xFormers](../optimization/xformers) for memory efficient attention. This could help make your training faster and less memory intensive.
|
||||
|
||||
@@ -54,6 +57,7 @@ If possible, please [install xFormers](../optimization/xformers) for memory effi
|
||||
| [**ControlNet**](./controlnet) | ✅ | ✅ | - |
|
||||
| [**InstructPix2Pix**](./instructpix2pix) | ✅ | ✅ | - |
|
||||
| [**Custom Diffusion**](./custom_diffusion) | ✅ | ✅ | - |
|
||||
| [**T2I Adapters**](./t2i_adapters) | ✅ | ✅ | - |
|
||||
|
||||
## Community
|
||||
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# T2I-Adapters for Stable Diffusion XL (SDXL)
|
||||
|
||||
The `train_t2i_adapter_sdxl.py` script (as shown below) shows how to implement the [T2I-Adapter training procedure](https://hf.co/papers/2302.08453) for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/t2i_adapter` folder and run
|
||||
```bash
|
||||
pip install -r requirements_sdxl.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
|
||||
## Circle filling dataset
|
||||
|
||||
The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
|
||||
|
||||
## Training
|
||||
|
||||
Our training examples use two test conditioning images. They can be downloaded by running
|
||||
|
||||
```sh
|
||||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
|
||||
|
||||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
|
||||
```
|
||||
|
||||
Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained T2IAdapter parameters to Hugging Face Hub.
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
export OUTPUT_DIR="path to save model"
|
||||
|
||||
accelerate launch train_t2i_adapter_sdxl.py \
|
||||
--pretrained_model_name_or_path=$MODEL_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--dataset_name=fusing/fill50k \
|
||||
--mixed_precision="fp16" \
|
||||
--resolution=1024 \
|
||||
--learning_rate=1e-5 \
|
||||
--max_train_steps=15000 \
|
||||
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
|
||||
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
|
||||
--validation_steps=100 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--report_to="wandb" \
|
||||
--seed=42 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
|
||||
* `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
Our experiments were conducted on a single 40GB A100 GPU.
|
||||
|
||||
### Inference
|
||||
|
||||
Once training is done, we can perform inference like so:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter, EulerAncestralDiscreteSchedulerTest
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
adapter_path = "path to adapter"
|
||||
|
||||
adapter = T2IAdapter.from_pretrained(adapter_path, torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
|
||||
base_model_path, adapter=adapter, torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
# speed up diffusion process with faster scheduler and memory optimization
|
||||
pipe.scheduler = EulerAncestralDiscreteSchedulerTest.from_config(pipe.scheduler.config)
|
||||
# remove following line if xformers is not installed or when using Torch 2.0.
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
# memory optimization.
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
control_image = load_image("./conditioning_image_1.png")
|
||||
prompt = "pale golden rod circle with old lace background"
|
||||
|
||||
# generate image
|
||||
generator = torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=20, generator=generator, image=control_image
|
||||
).images[0]
|
||||
image.save("./output.png")
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
### Specifying a better VAE
|
||||
|
||||
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
|
||||
@@ -116,7 +116,7 @@ mask_image_arr[mask_image_arr < 0.5] = 0
|
||||
mask_image_arr[mask_image_arr >= 0.5] = 1
|
||||
|
||||
# Take the masked pixels from the repainted image and the unmasked pixels from the initial image
|
||||
unmasked_unchanged_image_arr = (1 - mask_image_arr) * init_image_arr + mask_image_arr * repainted_image_arr
|
||||
unmasked_unchanged_image_arr = (1 - mask_image_arr) * init_image + mask_image_arr * repainted_image
|
||||
unmasked_unchanged_image = PIL.Image.fromarray(unmasked_unchanged_image_arr.round().astype("uint8"))
|
||||
unmasked_unchanged_image.save("force_unmasked_unchanged.png")
|
||||
```
|
||||
|
||||
@@ -28,7 +28,7 @@ This is why it's important to understand how to control sources of randomness in
|
||||
|
||||
## Control randomness
|
||||
|
||||
During inference, pipelines rely heavily on random sampling operations which include creating the
|
||||
During inference, pipelines rely heavily on random sampling operations which include creating the
|
||||
Gaussian noise tensors to denoise and adding noise to the scheduling step.
|
||||
|
||||
Take a look at the tensor values in the [`DDIMPipeline`] after two inference steps:
|
||||
@@ -47,7 +47,7 @@ image = ddim(num_inference_steps=2, output_type="np").images
|
||||
print(np.abs(image).sum())
|
||||
```
|
||||
|
||||
Running the code above prints one value, but if you run it again you get a different value. What is going on here?
|
||||
Running the code above prints one value, but if you run it again you get a different value. What is going on here?
|
||||
|
||||
Every time the pipeline is run, [`torch.randn`](https://pytorch.org/docs/stable/generated/torch.randn.html) uses a different random seed to create Gaussian noise which is denoised stepwise. This leads to a different result each time it is run, which is great for diffusion pipelines since it generates a different random image each time.
|
||||
|
||||
@@ -81,16 +81,16 @@ If you run this code example on your specific hardware and PyTorch version, you
|
||||
|
||||
<Tip>
|
||||
|
||||
💡 It might be a bit unintuitive at first to pass `Generator` objects to the pipeline instead of
|
||||
just integer values representing the seed, but this is the recommended design when dealing with
|
||||
probabilistic models in PyTorch as `Generator`'s are *random states* that can be
|
||||
💡 It might be a bit unintuitive at first to pass `Generator` objects to the pipeline instead of
|
||||
just integer values representing the seed, but this is the recommended design when dealing with
|
||||
probabilistic models in PyTorch as `Generator`'s are *random states* that can be
|
||||
passed to multiple pipelines in a sequence.
|
||||
|
||||
</Tip>
|
||||
|
||||
### GPU
|
||||
|
||||
Writing a reproducible pipeline on a GPU is a bit trickier, and full reproducibility across different hardware is not guaranteed because matrix multiplication - which diffusion pipelines require a lot of - is less deterministic on a GPU than a CPU. For example, if you run the same code example above on a GPU:
|
||||
Writing a reproducible pipeline on a GPU is a bit trickier, and full reproducibility across different hardware is not guaranteed because matrix multiplication - which diffusion pipelines require a lot of - is less deterministic on a GPU than a CPU. For example, if you run the same code example above on a GPU:
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -113,7 +113,7 @@ print(np.abs(image).sum())
|
||||
|
||||
The result is not the same even though you're using an identical seed because the GPU uses a different random number generator than the CPU.
|
||||
|
||||
To circumvent this problem, 🧨 Diffusers has a [`~diffusers.utils.randn_tensor`] function for creating random noise on the CPU, and then moving the tensor to a GPU if necessary. The `randn_tensor` function is used everywhere inside the pipeline, allowing the user to **always** pass a CPU `Generator` even if the pipeline is run on a GPU.
|
||||
To circumvent this problem, 🧨 Diffusers has a [`~diffusers.utils.torch_utils.randn_tensor`] function for creating random noise on the CPU, and then moving the tensor to a GPU if necessary. The `randn_tensor` function is used everywhere inside the pipeline, allowing the user to **always** pass a CPU `Generator` even if the pipeline is run on a GPU.
|
||||
|
||||
You'll see the results are much closer now!
|
||||
|
||||
@@ -139,14 +139,14 @@ print(np.abs(image).sum())
|
||||
<Tip>
|
||||
|
||||
💡 If reproducibility is important, we recommend always passing a CPU generator.
|
||||
The performance loss is often neglectable, and you'll generate much more similar
|
||||
The performance loss is often neglectable, and you'll generate much more similar
|
||||
values than if the pipeline had been run on a GPU.
|
||||
|
||||
</Tip>
|
||||
|
||||
Finally, for more complex pipelines such as [`UnCLIPPipeline`], these are often extremely
|
||||
susceptible to precision error propagation. Don't expect similar results across
|
||||
different GPU hardware or PyTorch versions. In this case, you'll need to run
|
||||
Finally, for more complex pipelines such as [`UnCLIPPipeline`], these are often extremely
|
||||
susceptible to precision error propagation. Don't expect similar results across
|
||||
different GPU hardware or PyTorch versions. In this case, you'll need to run
|
||||
exactly the same hardware and PyTorch version for full reproducibility.
|
||||
|
||||
## Deterministic algorithms
|
||||
|
||||
@@ -61,7 +61,7 @@ refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
|
||||
|
||||
## Text-to-image
|
||||
|
||||
For text-to-image, pass a text prompt:
|
||||
For text-to-image, pass a text prompt. By default, SDXL generates a 1024x1024 image for the best results. You can try setting the `height` and `width` parameters to 768x768 or 512x512, but anything below 512x512 is not likely to work.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
|
||||
@@ -19,10 +19,8 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.utils import (
|
||||
PIL_INTERPOLATION,
|
||||
randn_tensor,
|
||||
)
|
||||
from diffusers.utils import PIL_INTERPOLATION
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
def preprocess(image, w, h):
|
||||
|
||||
@@ -19,11 +19,8 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.utils import (
|
||||
PIL_INTERPOLATION,
|
||||
deprecate,
|
||||
randn_tensor,
|
||||
)
|
||||
from diffusers.utils import PIL_INTERPOLATION, deprecate
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
|
||||
@@ -20,7 +20,7 @@ from torchvision import transforms
|
||||
|
||||
from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from diffusers.schedulers import DDIMScheduler
|
||||
from diffusers.utils import randn_tensor
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
trans = transforms.Compose(
|
||||
|
||||
@@ -21,8 +21,8 @@ from diffusers.utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
randn_tensor,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
@@ -30,9 +30,9 @@ from diffusers.utils import (
|
||||
is_accelerate_version,
|
||||
is_invisible_watermark_available,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
if is_invisible_watermark_available():
|
||||
@@ -1138,7 +1138,7 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# 7.1 Apply denoising_end
|
||||
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
|
||||
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from diffuser.utils.torch_utils import randn_tensor
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
@@ -30,7 +31,6 @@ from diffusers.schedulers import EulerAncestralDiscreteScheduler, KarrasDiffusio
|
||||
from diffusers.utils import (
|
||||
deprecate,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
|
||||
|
||||
@@ -35,9 +35,9 @@ from diffusers.utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from diffuser.utils.torch_utils import randn_tensor
|
||||
from PIL import Image
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
@@ -19,7 +20,6 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
deprecate,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import PIL.Image
|
||||
import pycuda.driver as cuda
|
||||
import tensorrt as trt
|
||||
import torch
|
||||
from diffuser.utils.torch_utils import randn_tensor
|
||||
from PIL import Image
|
||||
from pycuda.tools import make_default_context
|
||||
from transformers import CLIPTokenizer
|
||||
@@ -23,7 +24,6 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
deprecate,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
|
||||
|
||||
@@ -16,9 +16,9 @@ from diffusers.utils import (
|
||||
PIL_INTERPOLATION,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -17,9 +17,9 @@ from diffusers.utils import (
|
||||
PIL_INTERPOLATION,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -16,9 +16,9 @@ from diffusers.utils import (
|
||||
PIL_INTERPOLATION,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -11,7 +11,8 @@ from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
|
||||
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.utils import is_compiled_module, logging, randn_tensor
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -31,9 +31,9 @@ from diffusers.utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -10,7 +10,8 @@ from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
|
||||
from diffusers.utils import PIL_INTERPOLATION, logging, randn_tensor
|
||||
from diffusers.utils import PIL_INTERPOLATION, logging
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -33,8 +33,8 @@ from diffusers.utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
randn_tensor,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -15,7 +15,8 @@ from diffusers.models.unet_2d_blocks import (
|
||||
UpBlock2D,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
||||
from diffusers.utils import PIL_INTERPOLATION, logging, randn_tensor
|
||||
from diffusers.utils import PIL_INTERPOLATION, logging
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -700,7 +701,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# 10.1 Apply denoising_end
|
||||
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
|
||||
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
|
||||
@@ -8,7 +8,8 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
|
||||
from diffusers.models import PriorTransformer
|
||||
from diffusers.pipelines import DiffusionPipeline, StableDiffusionImageVariationPipeline
|
||||
from diffusers.schedulers import UnCLIPScheduler
|
||||
from diffusers.utils import logging, randn_tensor
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -19,7 +19,8 @@ from diffusers import (
|
||||
UNet2DModel,
|
||||
)
|
||||
from diffusers.pipelines.unclip import UnCLIPTextProjModel
|
||||
from diffusers.utils import is_accelerate_available, logging, randn_tensor
|
||||
from diffusers.utils import is_accelerate_available, logging
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -15,7 +15,8 @@ from diffusers import (
|
||||
UNet2DModel,
|
||||
)
|
||||
from diffusers.pipelines.unclip import UnCLIPTextProjModel
|
||||
from diffusers.utils import is_accelerate_available, logging, randn_tensor
|
||||
from diffusers.utils import is_accelerate_available, logging
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -785,16 +785,17 @@ def main(args):
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
i = len(weights) - 1
|
||||
if accelerator.is_main_process:
|
||||
i = len(weights) - 1
|
||||
|
||||
while len(weights) > 0:
|
||||
weights.pop()
|
||||
model = models[i]
|
||||
while len(weights) > 0:
|
||||
weights.pop()
|
||||
model = models[i]
|
||||
|
||||
sub_dir = "controlnet"
|
||||
model.save_pretrained(os.path.join(output_dir, sub_dir))
|
||||
sub_dir = "controlnet"
|
||||
model.save_pretrained(os.path.join(output_dir, sub_dir))
|
||||
|
||||
i -= 1
|
||||
i -= 1
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
while len(models) > 0:
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -840,16 +840,17 @@ def main(args):
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
i = len(weights) - 1
|
||||
if accelerator.is_main_process:
|
||||
i = len(weights) - 1
|
||||
|
||||
while len(weights) > 0:
|
||||
weights.pop()
|
||||
model = models[i]
|
||||
while len(weights) > 0:
|
||||
weights.pop()
|
||||
model = models[i]
|
||||
|
||||
sub_dir = "controlnet"
|
||||
model.save_pretrained(os.path.join(output_dir, sub_dir))
|
||||
sub_dir = "controlnet"
|
||||
model.save_pretrained(os.path.join(output_dir, sub_dir))
|
||||
|
||||
i -= 1
|
||||
i -= 1
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
while len(models) > 0:
|
||||
|
||||
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -920,12 +920,13 @@ def main(args):
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
for model in models:
|
||||
sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
|
||||
model.save_pretrained(os.path.join(output_dir, sub_dir))
|
||||
if accelerator.is_main_process:
|
||||
for model in models:
|
||||
sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
|
||||
model.save_pretrained(os.path.join(output_dir, sub_dir))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
while len(models) > 0:
|
||||
|
||||
@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -894,27 +894,28 @@ def main(args):
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
# there are only two options here. Either are just the unet attn processor layers
|
||||
# or there are the unet and text encoder atten layers
|
||||
unet_lora_layers_to_save = None
|
||||
text_encoder_lora_layers_to_save = None
|
||||
if accelerator.is_main_process:
|
||||
# there are only two options here. Either are just the unet attn processor layers
|
||||
# or there are the unet and text encoder atten layers
|
||||
unet_lora_layers_to_save = None
|
||||
text_encoder_lora_layers_to_save = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
|
||||
text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
|
||||
text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
LoraLoaderMixin.save_lora_weights(
|
||||
output_dir,
|
||||
unet_lora_layers=unet_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
|
||||
)
|
||||
LoraLoaderMixin.save_lora_weights(
|
||||
output_dir,
|
||||
unet_lora_layers=unet_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
|
||||
)
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
unet_ = None
|
||||
|
||||
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -798,31 +798,32 @@ def main(args):
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
# there are only two options here. Either are just the unet attn processor layers
|
||||
# or there are the unet and text encoder atten layers
|
||||
unet_lora_layers_to_save = None
|
||||
text_encoder_one_lora_layers_to_save = None
|
||||
text_encoder_two_lora_layers_to_save = None
|
||||
if accelerator.is_main_process:
|
||||
# there are only two options here. Either are just the unet attn processor layers
|
||||
# or there are the unet and text encoder atten layers
|
||||
unet_lora_layers_to_save = None
|
||||
text_encoder_one_lora_layers_to_save = None
|
||||
text_encoder_two_lora_layers_to_save = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
StableDiffusionXLPipeline.save_lora_weights(
|
||||
output_dir,
|
||||
unet_lora_layers=unet_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
|
||||
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
|
||||
)
|
||||
StableDiffusionXLPipeline.save_lora_weights(
|
||||
output_dir,
|
||||
unet_lora_layers=unet_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
|
||||
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
|
||||
)
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
unet_ = None
|
||||
|
||||
@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -485,14 +485,15 @@ def main():
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if args.use_ema:
|
||||
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
|
||||
if accelerator.is_main_process:
|
||||
if args.use_ema:
|
||||
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
|
||||
|
||||
for i, model in enumerate(models):
|
||||
model.save_pretrained(os.path.join(output_dir, "unet"))
|
||||
for i, model in enumerate(models):
|
||||
model.save_pretrained(os.path.join(output_dir, "unet"))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
if args.use_ema:
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -63,6 +63,7 @@ DATASET_NAME_MAPPING = {
|
||||
"fusing/instructpix2pix-1000-samples": ("file_name", "edited_image", "edit_prompt"),
|
||||
}
|
||||
WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
|
||||
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(
|
||||
@@ -100,6 +101,16 @@ def parse_args():
|
||||
default=None,
|
||||
help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_precision",
|
||||
type="choice",
|
||||
choices=["fp32", "fp16", "bf16"],
|
||||
default="fp32",
|
||||
help=(
|
||||
"The vanilla SDXL 1.0 VAE can cause NaNs due to large activation values. Some custom models might already have a solution"
|
||||
" to this problem, and this flag allows you to use mixed precision to stabilize training."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
@@ -517,14 +528,15 @@ def main():
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if args.use_ema:
|
||||
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
|
||||
if accelerator.is_main_process:
|
||||
if args.use_ema:
|
||||
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
|
||||
|
||||
for i, model in enumerate(models):
|
||||
model.save_pretrained(os.path.join(output_dir, "unet"))
|
||||
for i, model in enumerate(models):
|
||||
model.save_pretrained(os.path.join(output_dir, "unet"))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
if args.use_ema:
|
||||
@@ -878,7 +890,7 @@ def main():
|
||||
if args.pretrained_vae_model_name_or_path is not None:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
else:
|
||||
vae.to(accelerator.device, dtype=torch.float32)
|
||||
vae.to(accelerator.device, dtype=TORCH_DTYPE_MAPPING[args.vae_precision])
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
|
||||
@@ -1010,16 +1010,17 @@ def main(args):
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
i = len(weights) - 1
|
||||
if accelerator.is_main_process:
|
||||
i = len(weights) - 1
|
||||
|
||||
while len(weights) > 0:
|
||||
weights.pop()
|
||||
model = models[i]
|
||||
while len(weights) > 0:
|
||||
weights.pop()
|
||||
model = models[i]
|
||||
|
||||
sub_dir = "controlnet"
|
||||
model.save_pretrained(os.path.join(output_dir, sub_dir))
|
||||
sub_dir = "controlnet"
|
||||
model.save_pretrained(os.path.join(output_dir, sub_dir))
|
||||
|
||||
i -= 1
|
||||
i -= 1
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
while len(models) > 0:
|
||||
|
||||
@@ -552,14 +552,15 @@ def main():
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if args.use_ema:
|
||||
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
|
||||
if accelerator.is_main_process:
|
||||
if args.use_ema:
|
||||
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
|
||||
|
||||
for i, model in enumerate(models):
|
||||
model.save_pretrained(os.path.join(output_dir, "unet"))
|
||||
for i, model in enumerate(models):
|
||||
model.save_pretrained(os.path.join(output_dir, "unet"))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
if args.use_ema:
|
||||
|
||||
+7
-6
@@ -313,14 +313,15 @@ def main(args):
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if args.use_ema:
|
||||
ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
|
||||
if accelerator.is_main_process:
|
||||
if args.use_ema:
|
||||
ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
|
||||
|
||||
for i, model in enumerate(models):
|
||||
model.save_pretrained(os.path.join(output_dir, "unet"))
|
||||
for i, model in enumerate(models):
|
||||
model.save_pretrained(os.path.join(output_dir, "unet"))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
if args.use_ema:
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
We don't yet support training T2I-Adapters on Stable Diffusion yet. For training T2I-Adapters on Stable Diffusion XL, refer [here](./README_sdxl.md).
|
||||
@@ -0,0 +1,131 @@
|
||||
# T2I-Adapter training example for Stable Diffusion XL (SDXL)
|
||||
|
||||
The `train_t2i_adapter_sdxl.py` script shows how to implement the [T2I-Adapter training procedure](https://hf.co/papers/2302.08453) for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/t2i_adapter` folder and run
|
||||
```bash
|
||||
pip install -r requirements_sdxl.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
|
||||
## Circle filling dataset
|
||||
|
||||
The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
|
||||
|
||||
## Training
|
||||
|
||||
Our training examples use two test conditioning images. They can be downloaded by running
|
||||
|
||||
```sh
|
||||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
|
||||
|
||||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
|
||||
```
|
||||
|
||||
Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained T2IAdapter parameters to Hugging Face Hub.
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
export OUTPUT_DIR="path to save model"
|
||||
|
||||
accelerate launch train_t2i_adapter_sdxl.py \
|
||||
--pretrained_model_name_or_path=$MODEL_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--dataset_name=fusing/fill50k \
|
||||
--mixed_precision="fp16" \
|
||||
--resolution=1024 \
|
||||
--learning_rate=1e-5 \
|
||||
--max_train_steps=15000 \
|
||||
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
|
||||
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
|
||||
--validation_steps=100 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--report_to="wandb" \
|
||||
--seed=42 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
|
||||
* `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
Our experiments were conducted on a single 40GB A100 GPU.
|
||||
|
||||
### Inference
|
||||
|
||||
Once training is done, we can perform inference like so:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter, EulerAncestralDiscreteSchedulerTest
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
adapter_path = "path to adapter"
|
||||
|
||||
adapter = T2IAdapter.from_pretrained(adapter_path, torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
|
||||
base_model_path, adapter=adapter, torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
# speed up diffusion process with faster scheduler and memory optimization
|
||||
pipe.scheduler = EulerAncestralDiscreteSchedulerTest.from_config(pipe.scheduler.config)
|
||||
# remove following line if xformers is not installed or when using Torch 2.0.
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
# memory optimization.
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
control_image = load_image("./conditioning_image_1.png")
|
||||
prompt = "pale golden rod circle with old lace background"
|
||||
|
||||
# generate image
|
||||
generator = torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=20, generator=generator, image=control_image
|
||||
).images[0]
|
||||
image.save("./output.png")
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
### Specifying a better VAE
|
||||
|
||||
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
|
||||
@@ -0,0 +1,8 @@
|
||||
transformers>=4.25.1
|
||||
accelerate>=0.16.0
|
||||
safetensors
|
||||
datasets
|
||||
torchvision
|
||||
ftfy
|
||||
tensorboard
|
||||
wandb
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1528,6 +1528,25 @@ class ExamplesTestsAccelerate(unittest.TestCase):
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
|
||||
|
||||
def test_t2i_adapter_sdxl(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/t2i_adapter/train_t2i_adapter_sdxl.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
|
||||
--adapter_model_name_or_path=hf-internal-testing/tiny-adapter
|
||||
--dataset_name=hf-internal-testing/fill10
|
||||
--output_dir={tmpdir}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=9
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
|
||||
|
||||
def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
|
||||
@@ -53,7 +53,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -629,14 +629,15 @@ def main():
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if args.use_ema:
|
||||
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
|
||||
if accelerator.is_main_process:
|
||||
if args.use_ema:
|
||||
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
|
||||
|
||||
for i, model in enumerate(models):
|
||||
model.save_pretrained(os.path.join(output_dir, "unet"))
|
||||
for i, model in enumerate(models):
|
||||
model.save_pretrained(os.path.join(output_dir, "unet"))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
if args.use_ema:
|
||||
|
||||
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -669,31 +669,32 @@ def main(args):
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
# there are only two options here. Either are just the unet attn processor layers
|
||||
# or there are the unet and text encoder atten layers
|
||||
unet_lora_layers_to_save = None
|
||||
text_encoder_one_lora_layers_to_save = None
|
||||
text_encoder_two_lora_layers_to_save = None
|
||||
if accelerator.is_main_process:
|
||||
# there are only two options here. Either are just the unet attn processor layers
|
||||
# or there are the unet and text encoder atten layers
|
||||
unet_lora_layers_to_save = None
|
||||
text_encoder_one_lora_layers_to_save = None
|
||||
text_encoder_two_lora_layers_to_save = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
StableDiffusionXLPipeline.save_lora_weights(
|
||||
output_dir,
|
||||
unet_lora_layers=unet_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
|
||||
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
|
||||
)
|
||||
StableDiffusionXLPipeline.save_lora_weights(
|
||||
output_dir,
|
||||
unet_lora_layers=unet_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
|
||||
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
|
||||
)
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
unet_ = None
|
||||
|
||||
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -651,14 +651,15 @@ def main(args):
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if args.use_ema:
|
||||
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
|
||||
if accelerator.is_main_process:
|
||||
if args.use_ema:
|
||||
ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
|
||||
|
||||
for i, model in enumerate(models):
|
||||
model.save_pretrained(os.path.join(output_dir, "unet"))
|
||||
for i, model in enumerate(models):
|
||||
model.save_pretrained(os.path.join(output_dir, "unet"))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
if args.use_ema:
|
||||
|
||||
@@ -79,7 +79,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -309,14 +309,15 @@ def main(args):
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if args.use_ema:
|
||||
ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
|
||||
if accelerator.is_main_process:
|
||||
if args.use_ema:
|
||||
ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
|
||||
|
||||
for i, model in enumerate(models):
|
||||
model.save_pretrained(os.path.join(output_dir, "unet"))
|
||||
for i, model in enumerate(models):
|
||||
model.save_pretrained(os.path.join(output_dir, "unet"))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
if args.use_ema:
|
||||
|
||||
@@ -154,6 +154,7 @@ if __name__ == "__main__":
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path_or_dict=args.checkpoint_path,
|
||||
original_config_file=args.original_config_file,
|
||||
config_files=args.config_files,
|
||||
image_size=args.image_size,
|
||||
prediction_type=args.prediction_type,
|
||||
model_type=args.pipeline_type,
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
# Run inside root directory of official source code: https://github.com/dome272/wuerstchen/
|
||||
import os
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextModel
|
||||
from vqgan import VQModel
|
||||
|
||||
from diffusers import (
|
||||
DDPMWuerstchenScheduler,
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
)
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior
|
||||
|
||||
|
||||
model_path = "models/"
|
||||
device = "cpu"
|
||||
|
||||
paella_vqmodel = VQModel()
|
||||
state_dict = torch.load(os.path.join(model_path, "vqgan_f4_v1_500k.pt"), map_location=device)["state_dict"]
|
||||
paella_vqmodel.load_state_dict(state_dict)
|
||||
|
||||
state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"]
|
||||
state_dict.pop("vquantizer.codebook.weight")
|
||||
vqmodel = PaellaVQModel(num_vq_embeddings=paella_vqmodel.codebook_size, latent_channels=paella_vqmodel.c_latent)
|
||||
vqmodel.load_state_dict(state_dict)
|
||||
|
||||
# Clip Text encoder and tokenizer
|
||||
text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
|
||||
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
|
||||
|
||||
# Generator
|
||||
gen_text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu")
|
||||
gen_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
|
||||
|
||||
orig_state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device)["state_dict"]
|
||||
state_dict = {}
|
||||
for key in orig_state_dict.keys():
|
||||
if key.endswith("in_proj_weight"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
|
||||
elif key.endswith("in_proj_bias"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
|
||||
elif key.endswith("out_proj.weight"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
|
||||
elif key.endswith("out_proj.bias"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
|
||||
else:
|
||||
state_dict[key] = orig_state_dict[key]
|
||||
deocder = WuerstchenDiffNeXt()
|
||||
deocder.load_state_dict(state_dict)
|
||||
|
||||
# Prior
|
||||
orig_state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device)["ema_state_dict"]
|
||||
state_dict = {}
|
||||
for key in orig_state_dict.keys():
|
||||
if key.endswith("in_proj_weight"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
|
||||
elif key.endswith("in_proj_bias"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
|
||||
elif key.endswith("out_proj.weight"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
|
||||
elif key.endswith("out_proj.bias"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
|
||||
else:
|
||||
state_dict[key] = orig_state_dict[key]
|
||||
prior_model = WuerstchenPrior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device)
|
||||
prior_model.load_state_dict(state_dict)
|
||||
|
||||
# scheduler
|
||||
scheduler = DDPMWuerstchenScheduler()
|
||||
|
||||
# Prior pipeline
|
||||
prior_pipeline = WuerstchenPriorPipeline(
|
||||
prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
|
||||
)
|
||||
|
||||
prior_pipeline.save_pretrained("warp-ai/wuerstchen-prior")
|
||||
|
||||
decoder_pipeline = WuerstchenDecoderPipeline(
|
||||
text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=deocder, scheduler=scheduler
|
||||
)
|
||||
decoder_pipeline.save_pretrained("warp-ai/wuerstchen")
|
||||
|
||||
# Wuerstchen pipeline
|
||||
wuerstchen_pipeline = WuerstchenCombinedPipeline(
|
||||
# Decoder
|
||||
text_encoder=gen_text_encoder,
|
||||
tokenizer=gen_tokenizer,
|
||||
decoder=deocder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqmodel,
|
||||
# Prior
|
||||
prior_tokenizer=tokenizer,
|
||||
prior_text_encoder=text_encoder,
|
||||
prior=prior_model,
|
||||
prior_scheduler=scheduler,
|
||||
)
|
||||
wuerstchen_pipeline.save_pretrained("warp-ai/WuerstchenCombinedPipeline")
|
||||
@@ -244,7 +244,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.21.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.21.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
+613
-208
@@ -1,13 +1,12 @@
|
||||
__version__ = "0.21.0.dev0"
|
||||
__version__ = "0.21.2"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
is_flax_available,
|
||||
is_inflect_available,
|
||||
is_invisible_watermark_available,
|
||||
is_k_diffusion_available,
|
||||
is_k_diffusion_version,
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
@@ -15,268 +14,363 @@ from .utils import (
|
||||
is_torch_available,
|
||||
is_torchsde_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
is_unidecode_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
# Lazy Import based on
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py
|
||||
|
||||
# When adding a new object to this init, please add it to `_import_structure`. The `_import_structure` is a dictionary submodule to list of object names,
|
||||
# and is used to defer the actual importing for when the objects are requested.
|
||||
# This way `import diffusers` provides the names in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
_import_structure = {
|
||||
"configuration_utils": ["ConfigMixin"],
|
||||
"models": [],
|
||||
"pipelines": [],
|
||||
"schedulers": [],
|
||||
"utils": [
|
||||
"OptionalDependencyNotAvailable",
|
||||
"is_flax_available",
|
||||
"is_inflect_available",
|
||||
"is_invisible_watermark_available",
|
||||
"is_k_diffusion_available",
|
||||
"is_k_diffusion_version",
|
||||
"is_librosa_available",
|
||||
"is_note_seq_available",
|
||||
"is_onnx_available",
|
||||
"is_scipy_available",
|
||||
"is_torch_available",
|
||||
"is_torchsde_available",
|
||||
"is_transformers_available",
|
||||
"is_transformers_version",
|
||||
"is_unidecode_available",
|
||||
"logging",
|
||||
],
|
||||
}
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_onnx_objects import * # noqa F403
|
||||
from .utils import dummy_onnx_objects # noqa F403
|
||||
|
||||
_import_structure["utils.dummy_onnx_objects"] = [
|
||||
name for name in dir(dummy_onnx_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
else:
|
||||
from .pipelines import OnnxRuntimeModel
|
||||
_import_structure["pipelines"].extend(["OnnxRuntimeModel"])
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_pt_objects import * # noqa F403
|
||||
from .utils import dummy_pt_objects # noqa F403
|
||||
|
||||
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
|
||||
|
||||
else:
|
||||
from .models import (
|
||||
AsymmetricAutoencoderKL,
|
||||
AutoencoderKL,
|
||||
AutoencoderTiny,
|
||||
ControlNetModel,
|
||||
ModelMixin,
|
||||
MultiAdapter,
|
||||
PriorTransformer,
|
||||
T2IAdapter,
|
||||
T5FilmDecoder,
|
||||
Transformer2DModel,
|
||||
UNet1DModel,
|
||||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
UNet3DConditionModel,
|
||||
VQModel,
|
||||
_import_structure["models"].extend(
|
||||
[
|
||||
"AsymmetricAutoencoderKL",
|
||||
"AutoencoderKL",
|
||||
"AutoencoderTiny",
|
||||
"ControlNetModel",
|
||||
"ModelMixin",
|
||||
"MultiAdapter",
|
||||
"PriorTransformer",
|
||||
"T2IAdapter",
|
||||
"T5FilmDecoder",
|
||||
"Transformer2DModel",
|
||||
"UNet1DModel",
|
||||
"UNet2DConditionModel",
|
||||
"UNet2DModel",
|
||||
"UNet3DConditionModel",
|
||||
"VQModel",
|
||||
]
|
||||
)
|
||||
from .optimization import (
|
||||
get_constant_schedule,
|
||||
get_constant_schedule_with_warmup,
|
||||
get_cosine_schedule_with_warmup,
|
||||
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
get_polynomial_decay_schedule_with_warmup,
|
||||
get_scheduler,
|
||||
_import_structure["optimization"] = [
|
||||
"get_constant_schedule",
|
||||
"get_constant_schedule_with_warmup",
|
||||
"get_cosine_schedule_with_warmup",
|
||||
"get_cosine_with_hard_restarts_schedule_with_warmup",
|
||||
"get_linear_schedule_with_warmup",
|
||||
"get_polynomial_decay_schedule_with_warmup",
|
||||
"get_scheduler",
|
||||
]
|
||||
|
||||
_import_structure["pipelines"].extend(
|
||||
[
|
||||
"AudioPipelineOutput",
|
||||
"AutoPipelineForImage2Image",
|
||||
"AutoPipelineForInpainting",
|
||||
"AutoPipelineForText2Image",
|
||||
"ConsistencyModelPipeline",
|
||||
"DanceDiffusionPipeline",
|
||||
"DDIMPipeline",
|
||||
"DDPMPipeline",
|
||||
"DiffusionPipeline",
|
||||
"DiTPipeline",
|
||||
"ImagePipelineOutput",
|
||||
"KarrasVePipeline",
|
||||
"LDMPipeline",
|
||||
"LDMSuperResolutionPipeline",
|
||||
"PNDMPipeline",
|
||||
"RePaintPipeline",
|
||||
"ScoreSdeVePipeline",
|
||||
]
|
||||
)
|
||||
from .pipelines import (
|
||||
AudioPipelineOutput,
|
||||
AutoPipelineForImage2Image,
|
||||
AutoPipelineForInpainting,
|
||||
AutoPipelineForText2Image,
|
||||
CLIPImageProjection,
|
||||
ConsistencyModelPipeline,
|
||||
DanceDiffusionPipeline,
|
||||
DDIMPipeline,
|
||||
DDPMPipeline,
|
||||
DiffusionPipeline,
|
||||
DiTPipeline,
|
||||
ImagePipelineOutput,
|
||||
KarrasVePipeline,
|
||||
LDMPipeline,
|
||||
LDMSuperResolutionPipeline,
|
||||
PNDMPipeline,
|
||||
RePaintPipeline,
|
||||
ScoreSdeVePipeline,
|
||||
_import_structure["schedulers"].extend(
|
||||
[
|
||||
"CMStochasticIterativeScheduler",
|
||||
"DDIMInverseScheduler",
|
||||
"DDIMParallelScheduler",
|
||||
"DDIMScheduler",
|
||||
"DDPMParallelScheduler",
|
||||
"DDPMScheduler",
|
||||
"DDPMWuerstchenScheduler",
|
||||
"DEISMultistepScheduler",
|
||||
"DPMSolverMultistepInverseScheduler",
|
||||
"DPMSolverMultistepScheduler",
|
||||
"DPMSolverSinglestepScheduler",
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"HeunDiscreteScheduler",
|
||||
"IPNDMScheduler",
|
||||
"KarrasVeScheduler",
|
||||
"KDPM2AncestralDiscreteScheduler",
|
||||
"KDPM2DiscreteScheduler",
|
||||
"PNDMScheduler",
|
||||
"RePaintScheduler",
|
||||
"SchedulerMixin",
|
||||
"ScoreSdeVeScheduler",
|
||||
"UnCLIPScheduler",
|
||||
"UniPCMultistepScheduler",
|
||||
"VQDiffusionScheduler",
|
||||
]
|
||||
)
|
||||
from .schedulers import (
|
||||
CMStochasticIterativeScheduler,
|
||||
DDIMInverseScheduler,
|
||||
DDIMParallelScheduler,
|
||||
DDIMScheduler,
|
||||
DDPMParallelScheduler,
|
||||
DDPMScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverMultistepInverseScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
IPNDMScheduler,
|
||||
KarrasVeScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
RePaintScheduler,
|
||||
SchedulerMixin,
|
||||
ScoreSdeVeScheduler,
|
||||
UnCLIPScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
VQDiffusionScheduler,
|
||||
)
|
||||
from .training_utils import EMAModel
|
||||
_import_structure["training_utils"] = ["EMAModel"]
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_scipy_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_scipy_objects import * # noqa F403
|
||||
from .utils import dummy_torch_and_scipy_objects # noqa F403
|
||||
|
||||
_import_structure["utils.dummy_torch_and_scipy_objects"] = [
|
||||
name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
else:
|
||||
from .schedulers import LMSDiscreteScheduler
|
||||
_import_structure["schedulers"].extend(["LMSDiscreteScheduler"])
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_torchsde_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
|
||||
from .utils import dummy_torch_and_torchsde_objects # noqa F403
|
||||
|
||||
_import_structure["utils.dummy_torch_and_torchsde_objects"] = [
|
||||
name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
else:
|
||||
from .schedulers import DPMSolverSDEScheduler
|
||||
_import_structure["schedulers"].extend(["DPMSolverSDEScheduler"])
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
from .utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_import_structure["utils.dummy_torch_and_transformers_objects"] = [
|
||||
name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
else:
|
||||
from .pipelines import (
|
||||
AltDiffusionImg2ImgPipeline,
|
||||
AltDiffusionPipeline,
|
||||
AudioLDM2Pipeline,
|
||||
AudioLDM2ProjectionModel,
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
AudioLDMPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
IFImg2ImgPipeline,
|
||||
IFImg2ImgSuperResolutionPipeline,
|
||||
IFInpaintingPipeline,
|
||||
IFInpaintingSuperResolutionPipeline,
|
||||
IFPipeline,
|
||||
IFSuperResolutionPipeline,
|
||||
ImageTextPipelineOutput,
|
||||
KandinskyCombinedPipeline,
|
||||
KandinskyImg2ImgCombinedPipeline,
|
||||
KandinskyImg2ImgPipeline,
|
||||
KandinskyInpaintCombinedPipeline,
|
||||
KandinskyInpaintPipeline,
|
||||
KandinskyPipeline,
|
||||
KandinskyPriorPipeline,
|
||||
KandinskyV22CombinedPipeline,
|
||||
KandinskyV22ControlnetImg2ImgPipeline,
|
||||
KandinskyV22ControlnetPipeline,
|
||||
KandinskyV22Img2ImgCombinedPipeline,
|
||||
KandinskyV22Img2ImgPipeline,
|
||||
KandinskyV22InpaintCombinedPipeline,
|
||||
KandinskyV22InpaintPipeline,
|
||||
KandinskyV22Pipeline,
|
||||
KandinskyV22PriorEmb2EmbPipeline,
|
||||
KandinskyV22PriorPipeline,
|
||||
LDMTextToImagePipeline,
|
||||
MusicLDMPipeline,
|
||||
PaintByExamplePipeline,
|
||||
SemanticStableDiffusionPipeline,
|
||||
ShapEImg2ImgPipeline,
|
||||
ShapEPipeline,
|
||||
StableDiffusionAdapterPipeline,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionDiffEditPipeline,
|
||||
StableDiffusionGLIGENPipeline,
|
||||
StableDiffusionGLIGENTextImagePipeline,
|
||||
StableDiffusionImageVariationPipeline,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionInstructPix2PixPipeline,
|
||||
StableDiffusionLatentUpscalePipeline,
|
||||
StableDiffusionLDM3DPipeline,
|
||||
StableDiffusionModelEditingPipeline,
|
||||
StableDiffusionPanoramaPipeline,
|
||||
StableDiffusionParadigmsPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionPipelineSafe,
|
||||
StableDiffusionPix2PixZeroPipeline,
|
||||
StableDiffusionSAGPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
StableDiffusionXLAdapterPipeline,
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
TextToVideoSDPipeline,
|
||||
TextToVideoZeroPipeline,
|
||||
UnCLIPImageVariationPipeline,
|
||||
UnCLIPPipeline,
|
||||
UniDiffuserModel,
|
||||
UniDiffuserPipeline,
|
||||
UniDiffuserTextDecoder,
|
||||
VersatileDiffusionDualGuidedPipeline,
|
||||
VersatileDiffusionImageVariationPipeline,
|
||||
VersatileDiffusionPipeline,
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
VideoToVideoSDPipeline,
|
||||
VQDiffusionPipeline,
|
||||
_import_structure["pipelines"].extend(
|
||||
[
|
||||
"AltDiffusionImg2ImgPipeline",
|
||||
"AltDiffusionPipeline",
|
||||
"AudioLDM2Pipeline",
|
||||
"AudioLDM2ProjectionModel",
|
||||
"AudioLDM2UNet2DConditionModel",
|
||||
"AudioLDMPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
"IFImg2ImgPipeline",
|
||||
"IFImg2ImgSuperResolutionPipeline",
|
||||
"IFInpaintingPipeline",
|
||||
"IFInpaintingSuperResolutionPipeline",
|
||||
"IFPipeline",
|
||||
"IFSuperResolutionPipeline",
|
||||
"ImageTextPipelineOutput",
|
||||
"KandinskyCombinedPipeline",
|
||||
"KandinskyImg2ImgCombinedPipeline",
|
||||
"KandinskyImg2ImgPipeline",
|
||||
"KandinskyInpaintCombinedPipeline",
|
||||
"KandinskyInpaintPipeline",
|
||||
"KandinskyPipeline",
|
||||
"KandinskyPriorPipeline",
|
||||
"KandinskyV22CombinedPipeline",
|
||||
"KandinskyV22ControlnetImg2ImgPipeline",
|
||||
"KandinskyV22ControlnetPipeline",
|
||||
"KandinskyV22Img2ImgCombinedPipeline",
|
||||
"KandinskyV22Img2ImgPipeline",
|
||||
"KandinskyV22InpaintCombinedPipeline",
|
||||
"KandinskyV22InpaintPipeline",
|
||||
"KandinskyV22Pipeline",
|
||||
"KandinskyV22PriorEmb2EmbPipeline",
|
||||
"KandinskyV22PriorPipeline",
|
||||
"LDMTextToImagePipeline",
|
||||
"MusicLDMPipeline",
|
||||
"PaintByExamplePipeline",
|
||||
"SemanticStableDiffusionPipeline",
|
||||
"ShapEImg2ImgPipeline",
|
||||
"ShapEPipeline",
|
||||
"StableDiffusionAdapterPipeline",
|
||||
"StableDiffusionAttendAndExcitePipeline",
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionDepth2ImgPipeline",
|
||||
"StableDiffusionDiffEditPipeline",
|
||||
"StableDiffusionGLIGENPipeline",
|
||||
"StableDiffusionGLIGENTextImagePipeline",
|
||||
"StableDiffusionImageVariationPipeline",
|
||||
"StableDiffusionImg2ImgPipeline",
|
||||
"StableDiffusionInpaintPipeline",
|
||||
"StableDiffusionInpaintPipelineLegacy",
|
||||
"StableDiffusionInstructPix2PixPipeline",
|
||||
"StableDiffusionLatentUpscalePipeline",
|
||||
"StableDiffusionLDM3DPipeline",
|
||||
"StableDiffusionModelEditingPipeline",
|
||||
"StableDiffusionPanoramaPipeline",
|
||||
"StableDiffusionParadigmsPipeline",
|
||||
"StableDiffusionPipeline",
|
||||
"StableDiffusionPipelineSafe",
|
||||
"StableDiffusionPix2PixZeroPipeline",
|
||||
"StableDiffusionSAGPipeline",
|
||||
"StableDiffusionUpscalePipeline",
|
||||
"StableDiffusionXLAdapterPipeline",
|
||||
"StableDiffusionXLControlNetImg2ImgPipeline",
|
||||
"StableDiffusionXLControlNetInpaintPipeline",
|
||||
"StableDiffusionXLControlNetPipeline",
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLInstructPix2PixPipeline",
|
||||
"StableDiffusionXLPipeline",
|
||||
"StableUnCLIPImg2ImgPipeline",
|
||||
"StableUnCLIPPipeline",
|
||||
"TextToVideoSDPipeline",
|
||||
"TextToVideoZeroPipeline",
|
||||
"UnCLIPImageVariationPipeline",
|
||||
"UnCLIPPipeline",
|
||||
"UniDiffuserModel",
|
||||
"UniDiffuserPipeline",
|
||||
"UniDiffuserTextDecoder",
|
||||
"VersatileDiffusionDualGuidedPipeline",
|
||||
"VersatileDiffusionImageVariationPipeline",
|
||||
"VersatileDiffusionPipeline",
|
||||
"VersatileDiffusionTextToImagePipeline",
|
||||
"VideoToVideoSDPipeline",
|
||||
"VQDiffusionPipeline",
|
||||
"WuerstchenCombinedPipeline",
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
]
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
|
||||
from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
|
||||
|
||||
_import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [
|
||||
name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
else:
|
||||
from .pipelines import StableDiffusionKDiffusionPipeline
|
||||
_import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline"])
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
|
||||
from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403
|
||||
|
||||
_import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [
|
||||
name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
else:
|
||||
from .pipelines import (
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
OnnxStableDiffusionInpaintPipelineLegacy,
|
||||
OnnxStableDiffusionPipeline,
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
StableDiffusionOnnxPipeline,
|
||||
_import_structure["pipelines"].extend(
|
||||
[
|
||||
"OnnxStableDiffusionImg2ImgPipeline",
|
||||
"OnnxStableDiffusionInpaintPipeline",
|
||||
"OnnxStableDiffusionInpaintPipelineLegacy",
|
||||
"OnnxStableDiffusionPipeline",
|
||||
"OnnxStableDiffusionUpscalePipeline",
|
||||
"StableDiffusionOnnxPipeline",
|
||||
]
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_librosa_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_librosa_objects import * # noqa F403
|
||||
from .utils import dummy_torch_and_librosa_objects # noqa F403
|
||||
|
||||
_import_structure["utils.dummy_torch_and_librosa_objects"] = [
|
||||
name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
else:
|
||||
from .pipelines import AudioDiffusionPipeline, Mel
|
||||
_import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"])
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
|
||||
from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
|
||||
|
||||
_import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [
|
||||
name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
|
||||
else:
|
||||
from .pipelines import SpectrogramDiffusionPipeline
|
||||
_import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"])
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_flax_objects import * # noqa F403
|
||||
from .utils import dummy_flax_objects # noqa F403
|
||||
|
||||
_import_structure["utils.dummy_flax_objects"] = [
|
||||
name for name in dir(dummy_flax_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
|
||||
else:
|
||||
from .models.controlnet_flax import FlaxControlNetModel
|
||||
from .models.modeling_flax_utils import FlaxModelMixin
|
||||
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
from .models.vae_flax import FlaxAutoencoderKL
|
||||
from .pipelines import FlaxDiffusionPipeline
|
||||
from .schedulers import (
|
||||
FlaxDDIMScheduler,
|
||||
FlaxDDPMScheduler,
|
||||
FlaxDPMSolverMultistepScheduler,
|
||||
FlaxKarrasVeScheduler,
|
||||
FlaxLMSDiscreteScheduler,
|
||||
FlaxPNDMScheduler,
|
||||
FlaxSchedulerMixin,
|
||||
FlaxScoreSdeVeScheduler,
|
||||
_import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
|
||||
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
|
||||
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
||||
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
|
||||
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
|
||||
_import_structure["schedulers"].extend(
|
||||
[
|
||||
"FlaxDDIMScheduler",
|
||||
"FlaxDDPMScheduler",
|
||||
"FlaxDPMSolverMultistepScheduler",
|
||||
"FlaxKarrasVeScheduler",
|
||||
"FlaxLMSDiscreteScheduler",
|
||||
"FlaxPNDMScheduler",
|
||||
"FlaxSchedulerMixin",
|
||||
"FlaxScoreSdeVeScheduler",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -284,19 +378,330 @@ try:
|
||||
if not (is_flax_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
from .utils import dummy_flax_and_transformers_objects # noqa F403
|
||||
|
||||
_import_structure["utils.dummy_flax_and_transformers_objects"] = [
|
||||
name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
|
||||
else:
|
||||
from .pipelines import (
|
||||
FlaxStableDiffusionControlNetPipeline,
|
||||
FlaxStableDiffusionImg2ImgPipeline,
|
||||
FlaxStableDiffusionInpaintPipeline,
|
||||
FlaxStableDiffusionPipeline,
|
||||
_import_structure["pipelines"].extend(
|
||||
[
|
||||
"FlaxStableDiffusionControlNetPipeline",
|
||||
"FlaxStableDiffusionImg2ImgPipeline",
|
||||
"FlaxStableDiffusionInpaintPipeline",
|
||||
"FlaxStableDiffusionPipeline",
|
||||
]
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_note_seq_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_note_seq_objects import * # noqa F403
|
||||
from .utils import dummy_note_seq_objects # noqa F403
|
||||
|
||||
_import_structure["utils.dummy_note_seq_objects"] = [
|
||||
name for name in dir(dummy_note_seq_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
|
||||
else:
|
||||
from .pipelines import MidiProcessor
|
||||
_import_structure["pipelines"].extend(["MidiProcessor"])
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_utils import ConfigMixin
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_onnx_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import OnnxRuntimeModel
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .models import (
|
||||
AsymmetricAutoencoderKL,
|
||||
AutoencoderKL,
|
||||
AutoencoderTiny,
|
||||
ControlNetModel,
|
||||
ModelMixin,
|
||||
MultiAdapter,
|
||||
PriorTransformer,
|
||||
T2IAdapter,
|
||||
T5FilmDecoder,
|
||||
Transformer2DModel,
|
||||
UNet1DModel,
|
||||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
UNet3DConditionModel,
|
||||
VQModel,
|
||||
)
|
||||
from .optimization import (
|
||||
get_constant_schedule,
|
||||
get_constant_schedule_with_warmup,
|
||||
get_cosine_schedule_with_warmup,
|
||||
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
get_polynomial_decay_schedule_with_warmup,
|
||||
get_scheduler,
|
||||
)
|
||||
from .pipelines import (
|
||||
AudioPipelineOutput,
|
||||
AutoPipelineForImage2Image,
|
||||
AutoPipelineForInpainting,
|
||||
AutoPipelineForText2Image,
|
||||
CLIPImageProjection,
|
||||
ConsistencyModelPipeline,
|
||||
DanceDiffusionPipeline,
|
||||
DDIMPipeline,
|
||||
DDPMPipeline,
|
||||
DiffusionPipeline,
|
||||
DiTPipeline,
|
||||
ImagePipelineOutput,
|
||||
KarrasVePipeline,
|
||||
LDMPipeline,
|
||||
LDMSuperResolutionPipeline,
|
||||
PNDMPipeline,
|
||||
RePaintPipeline,
|
||||
ScoreSdeVePipeline,
|
||||
)
|
||||
from .schedulers import (
|
||||
CMStochasticIterativeScheduler,
|
||||
DDIMInverseScheduler,
|
||||
DDIMParallelScheduler,
|
||||
DDIMScheduler,
|
||||
DDPMParallelScheduler,
|
||||
DDPMScheduler,
|
||||
DDPMWuerstchenScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverMultistepInverseScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
IPNDMScheduler,
|
||||
KarrasVeScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
RePaintScheduler,
|
||||
SchedulerMixin,
|
||||
ScoreSdeVeScheduler,
|
||||
UnCLIPScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
VQDiffusionScheduler,
|
||||
)
|
||||
from .training_utils import EMAModel
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_scipy_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_scipy_objects import * # noqa F403
|
||||
else:
|
||||
from .schedulers import LMSDiscreteScheduler
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_torchsde_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
|
||||
else:
|
||||
from .schedulers import DPMSolverSDEScheduler
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import (
|
||||
AltDiffusionImg2ImgPipeline,
|
||||
AltDiffusionPipeline,
|
||||
AudioLDM2Pipeline,
|
||||
AudioLDM2ProjectionModel,
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
AudioLDMPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
IFImg2ImgPipeline,
|
||||
IFImg2ImgSuperResolutionPipeline,
|
||||
IFInpaintingPipeline,
|
||||
IFInpaintingSuperResolutionPipeline,
|
||||
IFPipeline,
|
||||
IFSuperResolutionPipeline,
|
||||
ImageTextPipelineOutput,
|
||||
KandinskyCombinedPipeline,
|
||||
KandinskyImg2ImgCombinedPipeline,
|
||||
KandinskyImg2ImgPipeline,
|
||||
KandinskyInpaintCombinedPipeline,
|
||||
KandinskyInpaintPipeline,
|
||||
KandinskyPipeline,
|
||||
KandinskyPriorPipeline,
|
||||
KandinskyV22CombinedPipeline,
|
||||
KandinskyV22ControlnetImg2ImgPipeline,
|
||||
KandinskyV22ControlnetPipeline,
|
||||
KandinskyV22Img2ImgCombinedPipeline,
|
||||
KandinskyV22Img2ImgPipeline,
|
||||
KandinskyV22InpaintCombinedPipeline,
|
||||
KandinskyV22InpaintPipeline,
|
||||
KandinskyV22Pipeline,
|
||||
KandinskyV22PriorEmb2EmbPipeline,
|
||||
KandinskyV22PriorPipeline,
|
||||
LDMTextToImagePipeline,
|
||||
MusicLDMPipeline,
|
||||
PaintByExamplePipeline,
|
||||
SemanticStableDiffusionPipeline,
|
||||
ShapEImg2ImgPipeline,
|
||||
ShapEPipeline,
|
||||
StableDiffusionAdapterPipeline,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionDiffEditPipeline,
|
||||
StableDiffusionGLIGENPipeline,
|
||||
StableDiffusionGLIGENTextImagePipeline,
|
||||
StableDiffusionImageVariationPipeline,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionInstructPix2PixPipeline,
|
||||
StableDiffusionLatentUpscalePipeline,
|
||||
StableDiffusionLDM3DPipeline,
|
||||
StableDiffusionModelEditingPipeline,
|
||||
StableDiffusionPanoramaPipeline,
|
||||
StableDiffusionParadigmsPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionPipelineSafe,
|
||||
StableDiffusionPix2PixZeroPipeline,
|
||||
StableDiffusionSAGPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
StableDiffusionXLAdapterPipeline,
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
TextToVideoSDPipeline,
|
||||
TextToVideoZeroPipeline,
|
||||
UnCLIPImageVariationPipeline,
|
||||
UnCLIPPipeline,
|
||||
UniDiffuserModel,
|
||||
UniDiffuserPipeline,
|
||||
UniDiffuserTextDecoder,
|
||||
VersatileDiffusionDualGuidedPipeline,
|
||||
VersatileDiffusionImageVariationPipeline,
|
||||
VersatileDiffusionPipeline,
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
VideoToVideoSDPipeline,
|
||||
VQDiffusionPipeline,
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import StableDiffusionKDiffusionPipeline
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import (
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
OnnxStableDiffusionInpaintPipelineLegacy,
|
||||
OnnxStableDiffusionPipeline,
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
StableDiffusionOnnxPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_librosa_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_librosa_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import AudioDiffusionPipeline, Mel
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import SpectrogramDiffusionPipeline
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_flax_objects import * # noqa F403
|
||||
else:
|
||||
from .models.controlnet_flax import FlaxControlNetModel
|
||||
from .models.modeling_flax_utils import FlaxModelMixin
|
||||
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
from .models.vae_flax import FlaxAutoencoderKL
|
||||
from .pipelines import FlaxDiffusionPipeline
|
||||
from .schedulers import (
|
||||
FlaxDDIMScheduler,
|
||||
FlaxDDPMScheduler,
|
||||
FlaxDPMSolverMultistepScheduler,
|
||||
FlaxKarrasVeScheduler,
|
||||
FlaxLMSDiscreteScheduler,
|
||||
FlaxPNDMScheduler,
|
||||
FlaxSchedulerMixin,
|
||||
FlaxScoreSdeVeScheduler,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_flax_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import (
|
||||
FlaxStableDiffusionControlNetPipeline,
|
||||
FlaxStableDiffusionImg2ImgPipeline,
|
||||
FlaxStableDiffusionInpaintPipeline,
|
||||
FlaxStableDiffusionPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_note_seq_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_note_seq_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import MidiProcessor
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={"__version__": __version__},
|
||||
)
|
||||
|
||||
@@ -18,8 +18,8 @@ import tqdm
|
||||
|
||||
from ...models.unet_1d import UNet1DModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...utils import randn_tensor
|
||||
from ...utils.dummy_pt_objects import DDPMScheduler
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
class ValueGuidedRLPipeline(DiffusionPipeline):
|
||||
@@ -76,7 +76,7 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
|
||||
return x_in * self.stds[key] + self.means[key]
|
||||
|
||||
def to_torch(self, x_in):
|
||||
if type(x_in) is dict:
|
||||
if isinstance(x_in, dict):
|
||||
return {k: self.to_torch(v) for k, v in x_in.items()}
|
||||
elif torch.is_tensor(x_in):
|
||||
return x_in.to(self.unet.device)
|
||||
|
||||
+573
-394
File diff suppressed because it is too large
Load Diff
@@ -12,27 +12,62 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..utils import is_flax_available, is_torch_available
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..utils import _LazyModule, is_flax_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {}
|
||||
|
||||
if is_torch_available():
|
||||
from .adapter import MultiAdapter, T2IAdapter
|
||||
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
|
||||
from .autoencoder_kl import AutoencoderKL
|
||||
from .autoencoder_tiny import AutoencoderTiny
|
||||
from .controlnet import ControlNetModel
|
||||
from .dual_transformer_2d import DualTransformer2DModel
|
||||
from .modeling_utils import ModelMixin
|
||||
from .prior_transformer import PriorTransformer
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .unet_1d import UNet1DModel
|
||||
from .unet_2d import UNet2DModel
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from .unet_3d_condition import UNet3DConditionModel
|
||||
from .vq_model import VQModel
|
||||
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
|
||||
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
|
||||
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
|
||||
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
|
||||
_import_structure["modeling_utils"] = ["ModelMixin"]
|
||||
_import_structure["prior_transformer"] = ["PriorTransformer"]
|
||||
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
|
||||
_import_structure["transformer_2d"] = ["Transformer2DModel"]
|
||||
_import_structure["transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
_import_structure["unet_1d"] = ["UNet1DModel"]
|
||||
_import_structure["unet_2d"] = ["UNet2DModel"]
|
||||
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
|
||||
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
|
||||
_import_structure["vq_model"] = ["VQModel"]
|
||||
|
||||
if is_flax_available():
|
||||
from .controlnet_flax import FlaxControlNetModel
|
||||
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
from .vae_flax import FlaxAutoencoderKL
|
||||
_import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
|
||||
_import_structure["unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
||||
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if is_torch_available():
|
||||
from .adapter import MultiAdapter, T2IAdapter
|
||||
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
|
||||
from .autoencoder_kl import AutoencoderKL
|
||||
from .autoencoder_tiny import AutoencoderTiny
|
||||
from .controlnet import ControlNetModel
|
||||
from .dual_transformer_2d import DualTransformer2DModel
|
||||
from .modeling_utils import ModelMixin
|
||||
from .prior_transformer import PriorTransformer
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
from .unet_1d import UNet1DModel
|
||||
from .unet_2d import UNet2DModel
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from .unet_3d_condition import UNet3DConditionModel
|
||||
from .vq_model import VQModel
|
||||
|
||||
if is_flax_available():
|
||||
from .controlnet_flax import FlaxControlNetModel
|
||||
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
from .vae_flax import FlaxAutoencoderKL
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
||||
@@ -90,6 +90,8 @@ class MultiAdapter(ModelMixin):
|
||||
features = adapter(x)
|
||||
if accume_state is None:
|
||||
accume_state = features
|
||||
for i in range(len(accume_state)):
|
||||
accume_state[i] = w * accume_state[i]
|
||||
else:
|
||||
for i in range(len(features)):
|
||||
accume_state[i] += w * features[i]
|
||||
|
||||
@@ -17,7 +17,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import maybe_allow_in_graph
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .activations import get_activation
|
||||
from .attention_processor import Attention
|
||||
from .embeddings import CombinedTimestepLabelEmbeddings
|
||||
|
||||
@@ -18,8 +18,9 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import deprecate, logging, maybe_allow_in_graph
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils.import_utils import is_xformers_available
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .lora import LoRACompatibleLinear, LoRALinearLayer
|
||||
|
||||
|
||||
@@ -188,7 +189,7 @@ class Attention(nn.Module):
|
||||
if use_memory_efficient_attention_xformers:
|
||||
if is_added_kv_processor and (is_lora or is_custom_diffusion):
|
||||
raise NotImplementedError(
|
||||
f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
|
||||
f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
|
||||
)
|
||||
if not is_xformers_available():
|
||||
raise ModuleNotFoundError(
|
||||
|
||||
@@ -17,7 +17,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import apply_forward_hook
|
||||
from ..utils.accelerate_utils import apply_forward_hook
|
||||
from .autoencoder_kl import AutoencoderKLOutput
|
||||
from .modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
|
||||
|
||||
@@ -19,7 +19,8 @@ import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalVAEMixin
|
||||
from ..utils import BaseOutput, apply_forward_hook
|
||||
from ..utils import BaseOutput
|
||||
from ..utils.accelerate_utils import apply_forward_hook
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
|
||||
@@ -19,7 +19,8 @@ from typing import Tuple, Union
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, apply_forward_hook
|
||||
from ..utils import BaseOutput
|
||||
from ..utils.accelerate_utils import apply_forward_hook
|
||||
from .modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DecoderTiny, EncoderTiny
|
||||
|
||||
|
||||
@@ -19,8 +19,7 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import ControlLoRAMixin, FromOriginalControlnetMixin, UNet2DConditionLoadersMixin
|
||||
from ..models.lora import LoRACompatibleConv
|
||||
from ..loaders import FromOriginalControlnetMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
@@ -81,7 +80,7 @@ class ControlNetConditioningEmbedding(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv_in = LoRACompatibleConv(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
||||
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
||||
|
||||
self.blocks = nn.ModuleList([])
|
||||
|
||||
@@ -97,7 +96,6 @@ class ControlNetConditioningEmbedding(nn.Module):
|
||||
|
||||
def forward(self, conditioning):
|
||||
embedding = self.conv_in(conditioning)
|
||||
print(f"From conv_in embedding of ControlNet: {embedding[0, :5, :5, -1]}")
|
||||
embedding = F.silu(embedding)
|
||||
|
||||
for block in self.blocks:
|
||||
@@ -109,9 +107,7 @@ class ControlNetConditioningEmbedding(nn.Module):
|
||||
return embedding
|
||||
|
||||
|
||||
class ControlNetModel(
|
||||
ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, FromOriginalControlnetMixin, ControlLoRAMixin
|
||||
):
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
@@ -251,7 +247,7 @@ class ControlNetModel(
|
||||
# input
|
||||
conv_in_kernel = 3
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
self.conv_in = LoRACompatibleConv(
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
||||
)
|
||||
|
||||
@@ -723,7 +719,6 @@ class ControlNetModel(
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
print(f"t_emb: {t_emb[0, :3]}")
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
@@ -731,8 +726,6 @@ class ControlNetModel(
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
print(f"emb: {emb[0, :3]}")
|
||||
|
||||
aug_emb = None
|
||||
|
||||
if self.class_embedding is not None:
|
||||
@@ -771,7 +764,6 @@ class ControlNetModel(
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
print(f"From ControlNet conv_in: {sample[0, :5, :5, -1]}")
|
||||
|
||||
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
||||
sample = sample + controlnet_cond
|
||||
|
||||
@@ -18,7 +18,6 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..models.lora import LoRACompatibleLinear
|
||||
from .activations import get_activation
|
||||
|
||||
|
||||
@@ -167,10 +166,10 @@ class TimestepEmbedding(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = LoRACompatibleLinear(cond_proj_dim, in_channels, bias=False)
|
||||
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
||||
else:
|
||||
self.cond_proj = None
|
||||
|
||||
@@ -180,7 +179,7 @@ class TimestepEmbedding(nn.Module):
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
|
||||
@@ -40,17 +40,7 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
|
||||
|
||||
|
||||
class LoRALinearLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
out_features,
|
||||
rank=4,
|
||||
network_alpha=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
# initial_weight=None,
|
||||
# initial_bias=None,
|
||||
):
|
||||
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
|
||||
super().__init__()
|
||||
|
||||
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
||||
@@ -62,10 +52,6 @@ class LoRALinearLayer(nn.Module):
|
||||
self.out_features = out_features
|
||||
self.in_features = in_features
|
||||
|
||||
# # Control-LoRA specific.
|
||||
# self.initial_weight = initial_weight
|
||||
# self.initial_bias = initial_bias
|
||||
|
||||
nn.init.normal_(self.down.weight, std=1 / rank)
|
||||
nn.init.zeros_(self.up.weight)
|
||||
|
||||
@@ -80,32 +66,11 @@ class LoRALinearLayer(nn.Module):
|
||||
up_hidden_states *= self.network_alpha / self.rank
|
||||
|
||||
return up_hidden_states.to(orig_dtype)
|
||||
# else:
|
||||
# initial_weight = self.initial_weight
|
||||
# if initial_weight.device != hidden_states.device:
|
||||
# initial_weight = initial_weight.to(hidden_states.device)
|
||||
# return torch.nn.functional.linear(
|
||||
# hidden_states.to(dtype),
|
||||
# initial_weight
|
||||
# + (torch.mm(self.up.weight.data.flatten(start_dim=1), self.down.weight.data.flatten(start_dim=1)))
|
||||
# .reshape(self.initial_weight.shape)
|
||||
# .type(orig_dtype),
|
||||
# self.initial_bias,
|
||||
# )
|
||||
|
||||
|
||||
class LoRAConv2dLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
out_features,
|
||||
rank=4,
|
||||
kernel_size=(1, 1),
|
||||
stride=(1, 1),
|
||||
padding=0,
|
||||
network_alpha=None,
|
||||
# initial_weight=None,
|
||||
# initial_bias=None,
|
||||
self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -119,13 +84,6 @@ class LoRAConv2dLayer(nn.Module):
|
||||
self.network_alpha = network_alpha
|
||||
self.rank = rank
|
||||
|
||||
# # Control-LoRA specific.
|
||||
# self.initial_weight = initial_weight
|
||||
# self.initial_bias = initial_bias
|
||||
# self.stride = stride
|
||||
# self.kernel_size = kernel_size
|
||||
# self.padding = padding
|
||||
|
||||
nn.init.normal_(self.down.weight, std=1 / rank)
|
||||
nn.init.zeros_(self.up.weight)
|
||||
|
||||
@@ -140,20 +98,6 @@ class LoRAConv2dLayer(nn.Module):
|
||||
up_hidden_states *= self.network_alpha / self.rank
|
||||
|
||||
return up_hidden_states.to(orig_dtype)
|
||||
# else:
|
||||
# initial_weight = self.initial_weight
|
||||
# if initial_weight.device != hidden_states.device:
|
||||
# initial_weight = initial_weight.to(hidden_states.device)
|
||||
# return torch.nn.functional.conv2d(
|
||||
# hidden_states,
|
||||
# initial_weight
|
||||
# + (torch.mm(self.up.weight.flatten(start_dim=1), self.down.weight.flatten(start_dim=1)))
|
||||
# .reshape(self.initial_weight.shape)
|
||||
# .type(orig_dtype),
|
||||
# self.initial_bias,
|
||||
# self.stride,
|
||||
# self.padding,
|
||||
# )
|
||||
|
||||
|
||||
class LoRACompatibleConv(nn.Conv2d):
|
||||
|
||||
@@ -128,6 +128,31 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
||||
)
|
||||
|
||||
|
||||
def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
unexpected_keys = []
|
||||
empty_state_dict = model.state_dict()
|
||||
for param_name, param in state_dict.items():
|
||||
if param_name not in empty_state_dict:
|
||||
unexpected_keys.append(param_name)
|
||||
continue
|
||||
|
||||
if empty_state_dict[param_name].shape != param.shape:
|
||||
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
|
||||
raise ValueError(
|
||||
f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
||||
)
|
||||
|
||||
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
|
||||
if accepts_dtype:
|
||||
set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
|
||||
else:
|
||||
set_module_tensor_to_device(model, param_name, device, value=param)
|
||||
return unexpected_keys
|
||||
|
||||
|
||||
def _load_state_dict_into_model(model_to_load, state_dict):
|
||||
# Convert old format to new format if needed from a PyTorch state_dict
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
@@ -624,29 +649,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
||||
" those weights or else make sure your checkpoint file is correct."
|
||||
)
|
||||
unexpected_keys = []
|
||||
|
||||
empty_state_dict = model.state_dict()
|
||||
for param_name, param in state_dict.items():
|
||||
accepts_dtype = "dtype" in set(
|
||||
inspect.signature(set_module_tensor_to_device).parameters.keys()
|
||||
)
|
||||
|
||||
if param_name not in empty_state_dict:
|
||||
unexpected_keys.append(param_name)
|
||||
continue
|
||||
|
||||
if empty_state_dict[param_name].shape != param.shape:
|
||||
raise ValueError(
|
||||
f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
||||
)
|
||||
|
||||
if accepts_dtype:
|
||||
set_module_tensor_to_device(
|
||||
model, param_name, param_device, value=param, dtype=torch_dtype
|
||||
)
|
||||
else:
|
||||
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
model,
|
||||
state_dict,
|
||||
device=param_device,
|
||||
dtype=torch_dtype,
|
||||
model_name_or_path=pretrained_model_name_or_path,
|
||||
)
|
||||
|
||||
if cls._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in cls._keys_to_ignore_on_load_unexpected:
|
||||
|
||||
@@ -18,7 +18,8 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..utils import BaseOutput, is_torch_version, randn_tensor
|
||||
from ..utils import BaseOutput, is_torch_version
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .activations import get_activation
|
||||
from .attention_processor import SpatialNorm
|
||||
from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, get_up_block
|
||||
@@ -52,7 +53,7 @@ class Encoder(nn.Module):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels,
|
||||
block_out_channels[0],
|
||||
kernel_size=3,
|
||||
|
||||
@@ -18,7 +18,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, apply_forward_hook
|
||||
from ..utils import BaseOutput
|
||||
from ..utils.accelerate_utils import apply_forward_hook
|
||||
from .modeling_utils import ModelMixin
|
||||
from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
|
||||
|
||||
@@ -132,7 +133,7 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
# also go through quantization layer
|
||||
if not force_not_quantize:
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
quant, _, _ = self.quantize(h)
|
||||
else:
|
||||
quant = h
|
||||
quant2 = self.post_quant_conv(quant)
|
||||
|
||||
+394
-135
@@ -1,5 +1,9 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_k_diffusion_available,
|
||||
is_librosa_available,
|
||||
@@ -10,186 +14,441 @@ from ..utils import (
|
||||
)
|
||||
|
||||
|
||||
# These modules contain pipelines from multiple libraries/frameworks
|
||||
_dummy_objects = {}
|
||||
_import_structure = {"stable_diffusion": [], "latent_diffusion": [], "controlnet": []}
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image
|
||||
from .consistency_models import ConsistencyModelPipeline
|
||||
from .dance_diffusion import DanceDiffusionPipeline
|
||||
from .ddim import DDIMPipeline
|
||||
from .ddpm import DDPMPipeline
|
||||
from .dit import DiTPipeline
|
||||
from .latent_diffusion import LDMSuperResolutionPipeline
|
||||
from .latent_diffusion_uncond import LDMPipeline
|
||||
from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
|
||||
from .pndm import PNDMPipeline
|
||||
from .repaint import RePaintPipeline
|
||||
from .score_sde_ve import ScoreSdeVePipeline
|
||||
from .stochastic_karras_ve import KarrasVePipeline
|
||||
from ..utils import dummy_pt_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
|
||||
else:
|
||||
_import_structure["auto_pipeline"] = [
|
||||
"AutoPipelineForImage2Image",
|
||||
"AutoPipelineForInpainting",
|
||||
"AutoPipelineForText2Image",
|
||||
]
|
||||
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
|
||||
_import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
|
||||
_import_structure["ddim"] = ["DDIMPipeline"]
|
||||
_import_structure["ddpm"] = ["DDPMPipeline"]
|
||||
_import_structure["dit"] = ["DiTPipeline"]
|
||||
_import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"])
|
||||
_import_structure["latent_diffusion_uncond"] = ["LDMPipeline"]
|
||||
_import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"]
|
||||
_import_structure["pndm"] = ["PNDMPipeline"]
|
||||
_import_structure["repaint"] = ["RePaintPipeline"]
|
||||
_import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"]
|
||||
_import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"]
|
||||
try:
|
||||
if not (is_torch_available() and is_librosa_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_librosa_objects import * # noqa F403
|
||||
else:
|
||||
from .audio_diffusion import AudioDiffusionPipeline, Mel
|
||||
from ..utils import dummy_torch_and_librosa_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects))
|
||||
else:
|
||||
_import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"]
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
from ..utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
|
||||
from .audioldm import AudioLDMPipeline
|
||||
from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
from .controlnet import (
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
_import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"]
|
||||
_import_structure["audioldm"] = ["AudioLDMPipeline"]
|
||||
_import_structure["audioldm2"] = [
|
||||
"AudioLDM2Pipeline",
|
||||
"AudioLDM2ProjectionModel",
|
||||
"AudioLDM2UNet2DConditionModel",
|
||||
]
|
||||
_import_structure["controlnet"].extend(
|
||||
[
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionXLControlNetImg2ImgPipeline",
|
||||
"StableDiffusionXLControlNetInpaintPipeline",
|
||||
"StableDiffusionXLControlNetPipeline",
|
||||
]
|
||||
)
|
||||
from .deepfloyd_if import (
|
||||
IFImg2ImgPipeline,
|
||||
IFImg2ImgSuperResolutionPipeline,
|
||||
IFInpaintingPipeline,
|
||||
IFInpaintingSuperResolutionPipeline,
|
||||
IFPipeline,
|
||||
IFSuperResolutionPipeline,
|
||||
_import_structure["deepfloyd_if"] = [
|
||||
"IFImg2ImgPipeline",
|
||||
"IFImg2ImgSuperResolutionPipeline",
|
||||
"IFInpaintingPipeline",
|
||||
"IFInpaintingSuperResolutionPipeline",
|
||||
"IFPipeline",
|
||||
"IFSuperResolutionPipeline",
|
||||
]
|
||||
_import_structure["kandinsky"] = [
|
||||
"KandinskyCombinedPipeline",
|
||||
"KandinskyImg2ImgCombinedPipeline",
|
||||
"KandinskyImg2ImgPipeline",
|
||||
"KandinskyInpaintCombinedPipeline",
|
||||
"KandinskyInpaintPipeline",
|
||||
"KandinskyPipeline",
|
||||
"KandinskyPriorPipeline",
|
||||
]
|
||||
_import_structure["kandinsky2_2"] = [
|
||||
"KandinskyV22CombinedPipeline",
|
||||
"KandinskyV22ControlnetImg2ImgPipeline",
|
||||
"KandinskyV22ControlnetPipeline",
|
||||
"KandinskyV22Img2ImgCombinedPipeline",
|
||||
"KandinskyV22Img2ImgPipeline",
|
||||
"KandinskyV22InpaintCombinedPipeline",
|
||||
"KandinskyV22InpaintPipeline",
|
||||
"KandinskyV22Pipeline",
|
||||
"KandinskyV22PriorEmb2EmbPipeline",
|
||||
"KandinskyV22PriorPipeline",
|
||||
]
|
||||
_import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
|
||||
_import_structure["musicldm"] = ["MusicLDMPipeline"]
|
||||
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
|
||||
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
|
||||
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
|
||||
_import_structure["stable_diffusion"].extend(
|
||||
[
|
||||
"CycleDiffusionPipeline",
|
||||
"StableDiffusionAttendAndExcitePipeline",
|
||||
"StableDiffusionDepth2ImgPipeline",
|
||||
"StableDiffusionDiffEditPipeline",
|
||||
"StableDiffusionGLIGENPipeline",
|
||||
"StableDiffusionGLIGENPipeline",
|
||||
"StableDiffusionGLIGENTextImagePipeline",
|
||||
"StableDiffusionImageVariationPipeline",
|
||||
"StableDiffusionImg2ImgPipeline",
|
||||
"StableDiffusionInpaintPipeline",
|
||||
"StableDiffusionInpaintPipelineLegacy",
|
||||
"StableDiffusionInstructPix2PixPipeline",
|
||||
"StableDiffusionLatentUpscalePipeline",
|
||||
"StableDiffusionLDM3DPipeline",
|
||||
"StableDiffusionModelEditingPipeline",
|
||||
"StableDiffusionPanoramaPipeline",
|
||||
"StableDiffusionParadigmsPipeline",
|
||||
"StableDiffusionPipeline",
|
||||
"StableDiffusionPix2PixZeroPipeline",
|
||||
"StableDiffusionSAGPipeline",
|
||||
"StableDiffusionUpscalePipeline",
|
||||
"StableUnCLIPImg2ImgPipeline",
|
||||
"StableUnCLIPPipeline",
|
||||
]
|
||||
)
|
||||
from .kandinsky import (
|
||||
KandinskyCombinedPipeline,
|
||||
KandinskyImg2ImgCombinedPipeline,
|
||||
KandinskyImg2ImgPipeline,
|
||||
KandinskyInpaintCombinedPipeline,
|
||||
KandinskyInpaintPipeline,
|
||||
KandinskyPipeline,
|
||||
KandinskyPriorPipeline,
|
||||
)
|
||||
from .kandinsky2_2 import (
|
||||
KandinskyV22CombinedPipeline,
|
||||
KandinskyV22ControlnetImg2ImgPipeline,
|
||||
KandinskyV22ControlnetPipeline,
|
||||
KandinskyV22Img2ImgCombinedPipeline,
|
||||
KandinskyV22Img2ImgPipeline,
|
||||
KandinskyV22InpaintCombinedPipeline,
|
||||
KandinskyV22InpaintPipeline,
|
||||
KandinskyV22Pipeline,
|
||||
KandinskyV22PriorEmb2EmbPipeline,
|
||||
KandinskyV22PriorPipeline,
|
||||
)
|
||||
from .latent_diffusion import LDMTextToImagePipeline
|
||||
from .musicldm import MusicLDMPipeline
|
||||
from .paint_by_example import PaintByExamplePipeline
|
||||
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
from .stable_diffusion import (
|
||||
CycleDiffusionPipeline,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionDiffEditPipeline,
|
||||
StableDiffusionGLIGENPipeline,
|
||||
StableDiffusionGLIGENTextImagePipeline,
|
||||
StableDiffusionImageVariationPipeline,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionInstructPix2PixPipeline,
|
||||
StableDiffusionLatentUpscalePipeline,
|
||||
StableDiffusionLDM3DPipeline,
|
||||
StableDiffusionModelEditingPipeline,
|
||||
StableDiffusionPanoramaPipeline,
|
||||
StableDiffusionParadigmsPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionPix2PixZeroPipeline,
|
||||
StableDiffusionSAGPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
from .stable_diffusion.clip_image_project_model import CLIPImageProjection
|
||||
from .stable_diffusion_safe import StableDiffusionPipelineSafe
|
||||
from .stable_diffusion_xl import (
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline
|
||||
from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline
|
||||
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
|
||||
from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder
|
||||
from .versatile_diffusion import (
|
||||
VersatileDiffusionDualGuidedPipeline,
|
||||
VersatileDiffusionImageVariationPipeline,
|
||||
VersatileDiffusionPipeline,
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
)
|
||||
from .vq_diffusion import VQDiffusionPipeline
|
||||
|
||||
|
||||
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
|
||||
_import_structure["stable_diffusion_xl"] = [
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLInstructPix2PixPipeline",
|
||||
"StableDiffusionXLPipeline",
|
||||
]
|
||||
_import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"]
|
||||
_import_structure["text_to_video_synthesis"] = [
|
||||
"TextToVideoSDPipeline",
|
||||
"TextToVideoZeroPipeline",
|
||||
"VideoToVideoSDPipeline",
|
||||
]
|
||||
_import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"]
|
||||
_import_structure["unidiffuser"] = [
|
||||
"ImageTextPipelineOutput",
|
||||
"UniDiffuserModel",
|
||||
"UniDiffuserPipeline",
|
||||
"UniDiffuserTextDecoder",
|
||||
]
|
||||
_import_structure["versatile_diffusion"] = [
|
||||
"VersatileDiffusionDualGuidedPipeline",
|
||||
"VersatileDiffusionImageVariationPipeline",
|
||||
"VersatileDiffusionPipeline",
|
||||
"VersatileDiffusionTextToImagePipeline",
|
||||
]
|
||||
_import_structure["vq_diffusion"] = ["VQDiffusionPipeline"]
|
||||
_import_structure["wuerstchen"] = [
|
||||
"WuerstchenCombinedPipeline",
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
]
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_onnx_objects import * # noqa F403
|
||||
else:
|
||||
from .onnx_utils import OnnxRuntimeModel
|
||||
from ..utils import dummy_onnx_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_onnx_objects))
|
||||
else:
|
||||
_import_structure["onnx_utils"] = ["OnnxRuntimeModel"]
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
|
||||
else:
|
||||
from .stable_diffusion import (
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
OnnxStableDiffusionInpaintPipelineLegacy,
|
||||
OnnxStableDiffusionPipeline,
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
StableDiffusionOnnxPipeline,
|
||||
)
|
||||
from ..utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects))
|
||||
else:
|
||||
_import_structure["stable_diffusion"].extend(
|
||||
[
|
||||
"OnnxStableDiffusionImg2ImgPipeline",
|
||||
"OnnxStableDiffusionInpaintPipeline",
|
||||
"OnnxStableDiffusionInpaintPipelineLegacy",
|
||||
"OnnxStableDiffusionPipeline",
|
||||
"OnnxStableDiffusionUpscalePipeline",
|
||||
"StableDiffusionOnnxPipeline",
|
||||
]
|
||||
)
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
|
||||
else:
|
||||
from .stable_diffusion import StableDiffusionKDiffusionPipeline
|
||||
from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
|
||||
else:
|
||||
_import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"])
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_flax_objects import * # noqa F403
|
||||
from ..utils import dummy_flax_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_flax_objects))
|
||||
else:
|
||||
from .pipeline_flax_utils import FlaxDiffusionPipeline
|
||||
|
||||
|
||||
_import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"]
|
||||
try:
|
||||
if not (is_flax_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
from ..utils import dummy_flax_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
|
||||
else:
|
||||
from .controlnet import FlaxStableDiffusionControlNetPipeline
|
||||
from .stable_diffusion import (
|
||||
FlaxStableDiffusionImg2ImgPipeline,
|
||||
FlaxStableDiffusionInpaintPipeline,
|
||||
FlaxStableDiffusionPipeline,
|
||||
_import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"])
|
||||
_import_structure["stable_diffusion"].extend(
|
||||
[
|
||||
"FlaxStableDiffusionImg2ImgPipeline",
|
||||
"FlaxStableDiffusionInpaintPipeline",
|
||||
"FlaxStableDiffusionPipeline",
|
||||
]
|
||||
)
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
|
||||
from ..utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
|
||||
else:
|
||||
from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline
|
||||
_import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_pt_objects import * # noqa F403
|
||||
|
||||
else:
|
||||
from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image
|
||||
from .consistency_models import ConsistencyModelPipeline
|
||||
from .dance_diffusion import DanceDiffusionPipeline
|
||||
from .ddim import DDIMPipeline
|
||||
from .ddpm import DDPMPipeline
|
||||
from .dit import DiTPipeline
|
||||
from .latent_diffusion import LDMSuperResolutionPipeline
|
||||
from .latent_diffusion_uncond import LDMPipeline
|
||||
from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
|
||||
from .pndm import PNDMPipeline
|
||||
from .repaint import RePaintPipeline
|
||||
from .score_sde_ve import ScoreSdeVePipeline
|
||||
from .stochastic_karras_ve import KarrasVePipeline
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_librosa_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_librosa_objects import *
|
||||
else:
|
||||
from .audio_diffusion import AudioDiffusionPipeline, Mel
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
|
||||
from .audioldm import AudioLDMPipeline
|
||||
from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
from .controlnet import (
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
)
|
||||
from .deepfloyd_if import (
|
||||
IFImg2ImgPipeline,
|
||||
IFImg2ImgSuperResolutionPipeline,
|
||||
IFInpaintingPipeline,
|
||||
IFInpaintingSuperResolutionPipeline,
|
||||
IFPipeline,
|
||||
IFSuperResolutionPipeline,
|
||||
)
|
||||
from .kandinsky import (
|
||||
KandinskyCombinedPipeline,
|
||||
KandinskyImg2ImgCombinedPipeline,
|
||||
KandinskyImg2ImgPipeline,
|
||||
KandinskyInpaintCombinedPipeline,
|
||||
KandinskyInpaintPipeline,
|
||||
KandinskyPipeline,
|
||||
KandinskyPriorPipeline,
|
||||
)
|
||||
from .kandinsky2_2 import (
|
||||
KandinskyV22CombinedPipeline,
|
||||
KandinskyV22ControlnetImg2ImgPipeline,
|
||||
KandinskyV22ControlnetPipeline,
|
||||
KandinskyV22Img2ImgCombinedPipeline,
|
||||
KandinskyV22Img2ImgPipeline,
|
||||
KandinskyV22InpaintCombinedPipeline,
|
||||
KandinskyV22InpaintPipeline,
|
||||
KandinskyV22Pipeline,
|
||||
KandinskyV22PriorEmb2EmbPipeline,
|
||||
KandinskyV22PriorPipeline,
|
||||
)
|
||||
from .latent_diffusion import LDMTextToImagePipeline
|
||||
from .musicldm import MusicLDMPipeline
|
||||
from .paint_by_example import PaintByExamplePipeline
|
||||
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
from .stable_diffusion import (
|
||||
CycleDiffusionPipeline,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionDiffEditPipeline,
|
||||
StableDiffusionGLIGENPipeline,
|
||||
StableDiffusionGLIGENTextImagePipeline,
|
||||
StableDiffusionImageVariationPipeline,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionInstructPix2PixPipeline,
|
||||
StableDiffusionLatentUpscalePipeline,
|
||||
StableDiffusionLDM3DPipeline,
|
||||
StableDiffusionModelEditingPipeline,
|
||||
StableDiffusionPanoramaPipeline,
|
||||
StableDiffusionParadigmsPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionPix2PixZeroPipeline,
|
||||
StableDiffusionSAGPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
from .stable_diffusion_safe import StableDiffusionPipelineSafe
|
||||
from .stable_diffusion_xl import (
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline
|
||||
from .text_to_video_synthesis import (
|
||||
TextToVideoSDPipeline,
|
||||
TextToVideoZeroPipeline,
|
||||
VideoToVideoSDPipeline,
|
||||
)
|
||||
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
|
||||
from .unidiffuser import (
|
||||
ImageTextPipelineOutput,
|
||||
UniDiffuserModel,
|
||||
UniDiffuserPipeline,
|
||||
UniDiffuserTextDecoder,
|
||||
)
|
||||
from .versatile_diffusion import (
|
||||
VersatileDiffusionDualGuidedPipeline,
|
||||
VersatileDiffusionImageVariationPipeline,
|
||||
VersatileDiffusionPipeline,
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
)
|
||||
from .vq_diffusion import VQDiffusionPipeline
|
||||
from .wuerstchen import (
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_onnx_objects import * # noqa F403
|
||||
|
||||
else:
|
||||
from .onnx_utils import OnnxRuntimeModel
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_onnx_objects import *
|
||||
else:
|
||||
from .stable_diffusion import (
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
OnnxStableDiffusionInpaintPipelineLegacy,
|
||||
OnnxStableDiffusionPipeline,
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
StableDiffusionOnnxPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
|
||||
else:
|
||||
from .stable_diffusion import StableDiffusionKDiffusionPipeline
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_flax_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_flax_utils import FlaxDiffusionPipeline
|
||||
|
||||
try:
|
||||
if not (is_flax_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_flax_and_transformers_objects import *
|
||||
else:
|
||||
from .controlnet import FlaxStableDiffusionControlNetPipeline
|
||||
from .stable_diffusion import (
|
||||
FlaxStableDiffusionImg2ImgPipeline,
|
||||
FlaxStableDiffusionInpaintPipeline,
|
||||
FlaxStableDiffusionPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
|
||||
|
||||
else:
|
||||
from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
|
||||
@@ -1,38 +1,52 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
from PIL import Image
|
||||
|
||||
from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
# Copied from diffusers.pipelines.stable_diffusion.__init__.StableDiffusionPipelineOutput with Stable->Alt
|
||||
class AltDiffusionPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Alt Diffusion pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
|
||||
num_channels)`.
|
||||
nsfw_content_detected (`List[bool]`)
|
||||
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
|
||||
`None` if safety checking could not be performed.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
nsfw_content_detected: Optional[List[bool]]
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
|
||||
from ...utils import dummy_torch_and_transformers_objects
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
from .modeling_roberta_series import RobertaSeriesModelWithTransformation
|
||||
from .pipeline_alt_diffusion import AltDiffusionPipeline
|
||||
from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline
|
||||
_import_structure["modeling_roberta_series"] = ["RobertaSeriesModelWithTransformation"]
|
||||
_import_structure["pipeline_alt_diffusion"] = ["AltDiffusionPipeline"]
|
||||
_import_structure["pipeline_alt_diffusion_img2img"] = ["AltDiffusionImg2ImgPipeline"]
|
||||
|
||||
_import_structure["pipeline_output"] = ["AltDiffusionPipelineOutput"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
|
||||
else:
|
||||
from .modeling_roberta_series import RobertaSeriesModelWithTransformation
|
||||
from .pipeline_alt_diffusion import AltDiffusionPipeline
|
||||
from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline
|
||||
from .pipeline_output import AltDiffusionPipelineOutput
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
|
||||
@@ -19,15 +19,14 @@ import torch
|
||||
from packaging import version
|
||||
from transformers import CLIPImageProcessor, XLMRobertaTokenizer
|
||||
|
||||
from diffusers.utils import is_accelerate_available, is_accelerate_version
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
|
||||
from ...utils import deprecate, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
|
||||
@@ -99,7 +98,9 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -220,34 +221,6 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offload all models to CPU to reduce memory usage with a low impact on performance. Moves one whole model at a
|
||||
time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs.
|
||||
Memory savings are lower than using `enable_sequential_cpu_offload`, but performance is much better due to the
|
||||
iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -749,9 +722,8 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
@@ -21,15 +21,14 @@ import torch
|
||||
from packaging import version
|
||||
from transformers import CLIPImageProcessor, XLMRobertaTokenizer
|
||||
|
||||
from diffusers.utils import is_accelerate_available, is_accelerate_version
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
|
||||
from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
|
||||
@@ -126,7 +125,9 @@ class AltDiffusionImg2ImgPipeline(
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -218,34 +219,6 @@ class AltDiffusionImg2ImgPipeline(
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offload all models to CPU to reduce memory usage with a low impact on performance. Moves one whole model at a
|
||||
time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs.
|
||||
Memory savings are lower than using `enable_sequential_cpu_offload`, but performance is much better due to the
|
||||
iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -772,9 +745,8 @@ class AltDiffusionImg2ImgPipeline(
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
|
||||
from ...utils import (
|
||||
BaseOutput,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_output.StableDiffusionPipelineOutput with Stable->Alt
|
||||
class AltDiffusionPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Alt Diffusion pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
|
||||
num_channels)`.
|
||||
nsfw_content_detected (`List[bool]`)
|
||||
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
|
||||
`None` if safety checking could not be performed.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
nsfw_content_detected: Optional[List[bool]]
|
||||
@@ -1,2 +1,23 @@
|
||||
from .mel import Mel
|
||||
from .pipeline_audio_diffusion import AudioDiffusionPipeline
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"mel": ["Mel"],
|
||||
"pipeline_audio_diffusion": ["AudioDiffusionPipeline"],
|
||||
}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .mel import Mel
|
||||
from .pipeline_audio_diffusion import AudioDiffusionPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
@@ -22,7 +22,7 @@ from PIL import Image
|
||||
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import DDIMScheduler, DDPMScheduler
|
||||
from ...utils import randn_tensor
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput
|
||||
from .mel import Mel
|
||||
|
||||
@@ -178,7 +178,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
|
||||
self.scheduler.set_timesteps(steps)
|
||||
step_generator = step_generator or generator
|
||||
# For backwards compatibility
|
||||
if type(self.unet.config.sample_size) == int:
|
||||
if isinstance(self.unet.config.sample_size, int):
|
||||
self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
|
||||
if noise is None:
|
||||
noise = randn_tensor(
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -13,5 +19,32 @@ except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import (
|
||||
AudioLDMPipeline,
|
||||
)
|
||||
|
||||
_dummy_objects.update({"AudioLDMPipeline": AudioLDMPipeline})
|
||||
else:
|
||||
from .pipeline_audioldm import AudioLDMPipeline
|
||||
_import_structure["pipeline_audioldm"] = ["AudioLDMPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import (
|
||||
AudioLDMPipeline,
|
||||
)
|
||||
|
||||
else:
|
||||
from .pipeline_audioldm import AudioLDMPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
|
||||
@@ -22,7 +22,8 @@ from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaT
|
||||
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import logging, randn_tensor, replace_example_docstring
|
||||
from ...utils import logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
||||
|
||||
|
||||
@@ -71,6 +72,7 @@ class AudioLDMPipeline(DiffusionPipeline):
|
||||
vocoder ([`~transformers.SpeechT5HifiGan`]):
|
||||
Vocoder of class `SpeechT5HifiGan`.
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,20 +1,49 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import (
|
||||
AudioLDM2Pipeline,
|
||||
AudioLDM2ProjectionModel,
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
)
|
||||
from ...utils import dummy_torch_and_transformers_objects
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
from .pipeline_audioldm2 import AudioLDM2Pipeline
|
||||
_import_structure["modeling_audioldm2"] = ["AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel"]
|
||||
_import_structure["pipeline_audioldm2"] = ["AudioLDM2Pipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
|
||||
else:
|
||||
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
from .pipeline_audioldm2 import AudioLDM2Pipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
|
||||
@@ -36,9 +36,9 @@ from ...utils import (
|
||||
is_accelerate_version,
|
||||
is_librosa_available,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
||||
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
|
||||
@@ -947,6 +947,8 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
# 8. Post-processing
|
||||
if not output_type == "latent":
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
@@ -52,6 +52,7 @@ from .stable_diffusion_xl import (
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
|
||||
|
||||
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
@@ -63,6 +64,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("kandinsky22", KandinskyV22CombinedPipeline),
|
||||
("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
|
||||
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
|
||||
("wuerstchen", WuerstchenCombinedPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -93,6 +95,7 @@ _AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict(
|
||||
[
|
||||
("kandinsky", KandinskyPipeline),
|
||||
("kandinsky22", KandinskyV22Pipeline),
|
||||
("wuerstchen", WuerstchenDecoderPipeline),
|
||||
]
|
||||
)
|
||||
_AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict(
|
||||
@@ -305,8 +308,6 @@ class AutoPipelineForText2Image(ConfigMixin):
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
user_agent = kwargs.pop("user_agent", {})
|
||||
|
||||
load_config_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
@@ -316,8 +317,6 @@ class AutoPipelineForText2Image(ConfigMixin):
|
||||
"use_auth_token": use_auth_token,
|
||||
"local_files_only": local_files_only,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
"user_agent": user_agent,
|
||||
}
|
||||
|
||||
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
|
||||
@@ -580,8 +579,6 @@ class AutoPipelineForImage2Image(ConfigMixin):
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
user_agent = kwargs.pop("user_agent", {})
|
||||
|
||||
load_config_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
@@ -591,8 +588,6 @@ class AutoPipelineForImage2Image(ConfigMixin):
|
||||
"use_auth_token": use_auth_token,
|
||||
"local_files_only": local_files_only,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
"user_agent": user_agent,
|
||||
}
|
||||
|
||||
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
|
||||
@@ -856,8 +851,6 @@ class AutoPipelineForInpainting(ConfigMixin):
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
user_agent = kwargs.pop("user_agent", {})
|
||||
|
||||
load_config_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
@@ -867,8 +860,6 @@ class AutoPipelineForInpainting(ConfigMixin):
|
||||
"use_auth_token": use_auth_token,
|
||||
"local_files_only": local_files_only,
|
||||
"revision": revision,
|
||||
"subfolder": subfolder,
|
||||
"user_agent": user_agent,
|
||||
}
|
||||
|
||||
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
|
||||
|
||||
@@ -1 +1,21 @@
|
||||
from .pipeline_consistency_models import ConsistencyModelPipeline
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
_LazyModule,
|
||||
)
|
||||
|
||||
|
||||
_import_structure = {"pipeline_consistency_models": ["ConsistencyModelPipeline"]}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .pipeline_consistency_models import ConsistencyModelPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
@@ -5,12 +5,10 @@ import torch
|
||||
from ...models import UNet2DModel
|
||||
from ...schedulers import CMStochasticIterativeScheduler
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
@@ -62,6 +60,7 @@ class ConsistencyModelPipeline(DiffusionPipeline):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only
|
||||
compatible with [`CMStochasticIterativeScheduler`].
|
||||
"""
|
||||
model_cpu_offload_seq = "unet"
|
||||
|
||||
def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) -> None:
|
||||
super().__init__()
|
||||
@@ -73,34 +72,6 @@ class ConsistencyModelPipeline(DiffusionPipeline):
|
||||
|
||||
self.safety_checker = None
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offload all models to CPU to reduce memory usage with a low impact on performance. Moves one whole model at a
|
||||
time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs.
|
||||
Memory savings are lower than using `enable_sequential_cpu_offload`, but performance is much better due to the
|
||||
iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.unet]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels, height, width)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
@@ -280,9 +251,8 @@ class ConsistencyModelPipeline(DiffusionPipeline):
|
||||
# 6. Post-process image sample
|
||||
image = self.postprocess_image(sample, output_type=output_type)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
@@ -1,25 +1,77 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
from .multicontrolnet import MultiControlNetModel
|
||||
from .pipeline_controlnet import StableDiffusionControlNetPipeline
|
||||
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
|
||||
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
|
||||
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
|
||||
_import_structure["multicontrolnet"] = ["MultiControlNetModel"]
|
||||
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
|
||||
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
|
||||
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
|
||||
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_flax_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
|
||||
|
||||
|
||||
if is_transformers_available() and is_flax_available():
|
||||
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .multicontrolnet import MultiControlNetModel
|
||||
from .pipeline_controlnet import StableDiffusionControlNetPipeline
|
||||
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
|
||||
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
|
||||
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
|
||||
@@ -29,13 +29,10 @@ from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_compiled_module,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import is_compiled_module, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -126,7 +123,9 @@ class StableDiffusionControlNetPipeline(
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -211,34 +210,6 @@ class StableDiffusionControlNetPipeline(
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
# the safety checker can offload the vae again
|
||||
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
||||
|
||||
# control net hook has be manually offloaded as it alternates with unet
|
||||
cpu_offload_with_hook(self.controlnet, device)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
@@ -1032,9 +1003,8 @@ class StableDiffusionControlNetPipeline(
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
@@ -28,13 +28,10 @@ from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_compiled_module,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import is_compiled_module, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -150,7 +147,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -235,34 +234,6 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
# the safety checker can offload the vae again
|
||||
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
||||
|
||||
# control net hook has be manually offloaded as it alternates with unet
|
||||
cpu_offload_with_hook(self.controlnet, device)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
@@ -1108,9 +1079,8 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
@@ -30,13 +30,10 @@ from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_compiled_module,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import is_compiled_module, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -274,7 +271,9 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -362,34 +361,6 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
# the safety checker can offload the vae again
|
||||
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
||||
|
||||
# control net hook has be manually offloaded as it alternates with unet
|
||||
cpu_offload_with_hook(self.controlnet, device)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
@@ -1374,9 +1345,8 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
@@ -38,12 +38,11 @@ from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_compiled_module,
|
||||
is_invisible_watermark_available,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import is_compiled_module, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .multicontrolnet import MultiControlNetModel
|
||||
|
||||
@@ -167,6 +166,7 @@ class StableDiffusionXLControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMi
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = ["tokenizer", "text_encoder"]
|
||||
|
||||
def __init__(
|
||||
@@ -249,38 +249,6 @@ class StableDiffusionXLControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMi
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
model_sequence = (
|
||||
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
|
||||
)
|
||||
model_sequence.extend([self.unet, self.vae])
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in model_sequence:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
cpu_offload_with_hook(self.controlnet, device)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user