Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 95dae4c91e | |||
| cb62b4ff6b | |||
| bc7a4d4917 | |||
| 76d795a9a6 | |||
| 6b5ee298da | |||
| 062bb8dc0e | |||
| 5063e3b89d | |||
| 8dba180885 | |||
| 5366db5df1 | |||
| e516858886 | |||
| 36a0bacc29 | |||
| 9ad0530fea | |||
| 45db049973 | |||
| a68f5062fb | |||
| b864d674a5 | |||
| 85dccab7fd | |||
| 87fd3ce32b | |||
| 109d5bbe0d | |||
| f277d5e540 | |||
| 28e8d1f6ec | |||
| 98a0712d69 | |||
| 324d18fba2 | |||
| ad8068e414 | |||
| b4cbbd5ed2 | |||
| 8b3d2aeaf8 | |||
| 57239dacd0 | |||
| de12776b3a | |||
| cc12f3ec92 | |||
| 0ea78f9707 | |||
| 5495073faf | |||
| d03c9099bc | |||
| 93df5bb670 |
+3
-3
@@ -40,7 +40,7 @@ In the following, we give an overview of different ways to contribute, ranked by
|
||||
As said before, **all contributions are valuable to the community**.
|
||||
In the following, we will explain each contribution a bit more in detail.
|
||||
|
||||
For all contributions 4.-9. you will need to open a PR. It is explained in detail how to do so in [Opening a pull requst](#how-to-open-a-pr)
|
||||
For all contributions 4.-9. you will need to open a PR. It is explained in detail how to do so in [Opening a pull request](#how-to-open-a-pr)
|
||||
|
||||
### 1. Asking and answering questions on the Diffusers discussion forum or on the Diffusers Discord
|
||||
|
||||
@@ -63,7 +63,7 @@ In the same spirit, you are of immense help to the community by answering such q
|
||||
|
||||
**Please** keep in mind that the more effort you put into asking or answering a question, the higher
|
||||
the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database.
|
||||
In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accesible*, and *well-formated/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
|
||||
In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formated/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
|
||||
|
||||
**NOTE about channels**:
|
||||
[*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that we posted some time ago.
|
||||
@@ -168,7 +168,7 @@ more precise, provide the link to a duplicated issue or redirect them to [the fo
|
||||
If you have verified that the issued bug report is correct and requires a correction in the source code,
|
||||
please have a look at the next sections.
|
||||
|
||||
For all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull requst](#how-to-open-a-pr) section.
|
||||
For all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull request](#how-to-open-a-pr) section.
|
||||
|
||||
### 4. Fixing a "Good first issue"
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
|
||||
FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
@@ -6,17 +6,17 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.9 \
|
||||
python3.9-dev \
|
||||
python3-pip \
|
||||
python3.9-venv && \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.9 \
|
||||
python3.9-dev \
|
||||
python3-pip \
|
||||
python3.9-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
@@ -26,21 +26,21 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3.9 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3.9 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark && \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark && \
|
||||
python3.9 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
omegaconf
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
omegaconf
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
|
||||
FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
@@ -6,16 +6,16 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
@@ -25,21 +25,21 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark && \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
omegaconf
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
omegaconf
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -17,6 +17,8 @@
|
||||
title: AutoPipeline
|
||||
- local: tutorials/basic_training
|
||||
title: Train a diffusion model
|
||||
- local: tutorials/using_peft_for_inference
|
||||
title: Inference with PEFT
|
||||
title: Tutorials
|
||||
- sections:
|
||||
- sections:
|
||||
|
||||
@@ -34,7 +34,7 @@ this in the generated mask, you simply have to set the embeddings related to the
|
||||
`source_prompt` and "dog" to `target_prompt`.
|
||||
* When generating partially inverted latents using `invert`, assign a caption or text embedding describing the
|
||||
overall image to the `prompt` argument to help guide the inverse latent sampling process. In most cases, the
|
||||
source concept is sufficently descriptive to yield good results, but feel free to explore alternatives.
|
||||
source concept is sufficiently descriptive to yield good results, but feel free to explore alternatives.
|
||||
* When calling the pipeline to generate the final edited image, assign the source concept to `negative_prompt`
|
||||
and the target concept to `prompt`. Taking the above example, you simply have to set the embeddings related to
|
||||
the phrases including "cat" to `negative_prompt` and "dog" to `prompt`.
|
||||
|
||||
@@ -396,7 +396,7 @@ t2i_pipe.unet.set_attn_processor(AttnAddedKVProcessor())
|
||||
```
|
||||
|
||||
With PyTorch >= 2.0, you can also use Kandinsky with `torch.compile` which depending
|
||||
on your hardware can signficantly speed-up your inference time once the model is compiled.
|
||||
on your hardware can significantly speed-up your inference time once the model is compiled.
|
||||
To use Kandinsksy with `torch.compile`, you can do:
|
||||
|
||||
```py
|
||||
|
||||
@@ -263,7 +263,7 @@ t2i_pipe.unet.set_attn_processor(AttnAddedKVProcessor())
|
||||
```
|
||||
|
||||
With PyTorch >= 2.0, you can also use Kandinsky with `torch.compile` which depending
|
||||
on your hardware can signficantly speed-up your inference time once the model is compiled.
|
||||
on your hardware can significantly speed-up your inference time once the model is compiled.
|
||||
To use Kandinsksy with `torch.compile`, you can do:
|
||||
|
||||
```py
|
||||
|
||||
@@ -40,7 +40,7 @@ In the following, we give an overview of different ways to contribute, ranked by
|
||||
As said before, **all contributions are valuable to the community**.
|
||||
In the following, we will explain each contribution a bit more in detail.
|
||||
|
||||
For all contributions 4.-9. you will need to open a PR. It is explained in detail how to do so in [Opening a pull requst](#how-to-open-a-pr)
|
||||
For all contributions 4.-9. you will need to open a PR. It is explained in detail how to do so in [Opening a pull request](#how-to-open-a-pr)
|
||||
|
||||
### 1. Asking and answering questions on the Diffusers discussion forum or on the Diffusers Discord
|
||||
|
||||
@@ -63,7 +63,7 @@ In the same spirit, you are of immense help to the community by answering such q
|
||||
|
||||
**Please** keep in mind that the more effort you put into asking or answering a question, the higher
|
||||
the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database.
|
||||
In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accesible*, and *well-formated/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
|
||||
In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formated/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
|
||||
|
||||
**NOTE about channels**:
|
||||
[*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that we posted some time ago.
|
||||
@@ -168,7 +168,7 @@ more precise, provide the link to a duplicated issue or redirect them to [the fo
|
||||
If you have verified that the issued bug report is correct and requires a correction in the source code,
|
||||
please have a look at the next sections.
|
||||
|
||||
For all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull requst](#how-to-open-a-pr) section.
|
||||
For all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull request](#how-to-open-a-pr) section.
|
||||
|
||||
### 4. Fixing a `Good first issue`
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images[0]
|
||||
```
|
||||
|
||||
Depending on GPU type, `torch.compile` can provide an *addtional speed-up* of **5-300x** on top of SDPA! If you're using more recent GPU architectures such as Ampere (A100, 3090), Ada (4090), and Hopper (H100), `torch.compile` is able to squeeze even more performance out of these GPUs.
|
||||
Depending on GPU type, `torch.compile` can provide an *additional speed-up* of **5-300x** on top of SDPA! If you're using more recent GPU architectures such as Ampere (A100, 3090), Ada (4090), and Hopper (H100), `torch.compile` is able to squeeze even more performance out of these GPUs.
|
||||
|
||||
Compilation requires some time to complete, so it is best suited for situations where you prepare your pipeline once and then perform the same type of inference operations multiple times. For example, calling the compiled pipeline on a different image size triggers compilation again which can be expensive.
|
||||
|
||||
|
||||
@@ -87,4 +87,4 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \
|
||||
|
||||
Now that you've created a dataset, you can plug it into the `train_data_dir` (if your dataset is local) or `dataset_name` (if your dataset is on the Hub) arguments of a training script.
|
||||
|
||||
For your next steps, feel free to try and use your dataset to train a model for [unconditional generation](uncondtional_training) or [text-to-image generation](text2image)!
|
||||
For your next steps, feel free to try and use your dataset to train a model for [unconditional generation](unconditional_training) or [text-to-image generation](text2image)!
|
||||
@@ -69,7 +69,7 @@ write_basic_config()
|
||||
|
||||
Now let's get our dataset. Download dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it. To use your own dataset, take a look at the [Create a dataset for training](create_dataset) guide.
|
||||
|
||||
We also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the the given target image. The following flags enable the regularization `with_prior_preservation`, `real_prior` with `prior_loss_weight=1.`.
|
||||
We also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the given target image. The following flags enable the regularization `with_prior_preservation`, `real_prior` with `prior_loss_weight=1.`.
|
||||
The `class_prompt` should be the category name same as target image. The collected real images are with text captions similar to the `class_prompt`. The retrieved image are saved in `class_data_dir`. You can disable `real_prior` to use generated images as regularization. To collect the real images use this command first before training.
|
||||
|
||||
```bash
|
||||
@@ -106,7 +106,7 @@ accelerate launch train_custom_diffusion.py \
|
||||
|
||||
**Use `--enable_xformers_memory_efficient_attention` for faster training with lower VRAM requirement (16GB per GPU). Follow [this guide](https://github.com/facebookresearch/xformers) for installation instructions.**
|
||||
|
||||
To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (whcih we HIGHLY recommend), follow these steps:
|
||||
To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (which we HIGHLY recommend), follow these steps:
|
||||
|
||||
* Install `wandb`: `pip install wandb`.
|
||||
* Authorize: `wandb login`.
|
||||
|
||||
@@ -527,8 +527,8 @@ base_model_id = "stabilityai/stable-diffusion-xl-base-0.9"
|
||||
pipeline = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
|
||||
pipeline.load_lora_weights(".", weight_name="Kamepan.safetensors")
|
||||
|
||||
prompt = "anime screencap, glint, drawing, best quality, light smile, shy, a full body of a girl wearing wedding dress in the middle of the forest beneath the trees, fireflies, big eyes, 2d, cute, anime girl, waifu, cel shading, magical girl, vivid colors, (outline:1.1), manga anime artstyle, masterpiece, offical wallpaper, glint <lora:kame_sdxl_v2:1>"
|
||||
negative_prompt = "(deformed, bad quality, sketch, depth of field, blurry:1.1), grainy, bad anatomy, bad perspective, old, ugly, realistic, cartoon, disney, bad propotions"
|
||||
prompt = "anime screencap, glint, drawing, best quality, light smile, shy, a full body of a girl wearing wedding dress in the middle of the forest beneath the trees, fireflies, big eyes, 2d, cute, anime girl, waifu, cel shading, magical girl, vivid colors, (outline:1.1), manga anime artstyle, masterpiece, official wallpaper, glint <lora:kame_sdxl_v2:1>"
|
||||
negative_prompt = "(deformed, bad quality, sketch, depth of field, blurry:1.1), grainy, bad anatomy, bad perspective, old, ugly, realistic, cartoon, disney, bad proportions"
|
||||
generator = torch.manual_seed(2947883060)
|
||||
num_inference_steps = 30
|
||||
guidance_scale = 7
|
||||
|
||||
@@ -192,7 +192,7 @@ been added to the text encoder embedding matrix and consequently been trained.
|
||||
<Tip>
|
||||
|
||||
💡 The community has created a large library of different textual inversion embedding vectors, called [sd-concepts-library](https://huggingface.co/sd-concepts-library).
|
||||
Instead of training textual inversion embeddings from scratch you can also see whether a fitting textual inversion embedding has already been added to the libary.
|
||||
Instead of training textual inversion embeddings from scratch you can also see whether a fitting textual inversion embedding has already been added to the library.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
@@ -284,22 +284,11 @@ Now you can wrap all these components together in a training loop with 🤗 Acce
|
||||
|
||||
```py
|
||||
>>> from accelerate import Accelerator
|
||||
>>> from huggingface_hub import HfFolder, Repository, whoami
|
||||
>>> from huggingface_hub import create_repo, upload_folder
|
||||
>>> from tqdm.auto import tqdm
|
||||
>>> from pathlib import Path
|
||||
>>> import os
|
||||
|
||||
|
||||
>>> def get_full_repo_name(model_id: str, organization: str = None, token: str = None):
|
||||
... if token is None:
|
||||
... token = HfFolder.get_token()
|
||||
... if organization is None:
|
||||
... username = whoami(token)["name"]
|
||||
... return f"{username}/{model_id}"
|
||||
... else:
|
||||
... return f"{organization}/{model_id}"
|
||||
|
||||
|
||||
>>> def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
|
||||
... # Initialize accelerator and tensorboard logging
|
||||
... accelerator = Accelerator(
|
||||
@@ -309,11 +298,12 @@ Now you can wrap all these components together in a training loop with 🤗 Acce
|
||||
... project_dir=os.path.join(config.output_dir, "logs"),
|
||||
... )
|
||||
... if accelerator.is_main_process:
|
||||
... if config.push_to_hub:
|
||||
... repo_name = get_full_repo_name(Path(config.output_dir).name)
|
||||
... repo = Repository(config.output_dir, clone_from=repo_name)
|
||||
... elif config.output_dir is not None:
|
||||
... if config.output_dir is not None:
|
||||
... os.makedirs(config.output_dir, exist_ok=True)
|
||||
... if config.push_to_hub:
|
||||
... repo_id = create_repo(
|
||||
... repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True
|
||||
... ).repo_id
|
||||
... accelerator.init_trackers("train_example")
|
||||
|
||||
... # Prepare everything
|
||||
@@ -371,7 +361,12 @@ Now you can wrap all these components together in a training loop with 🤗 Acce
|
||||
|
||||
... if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
|
||||
... if config.push_to_hub:
|
||||
... repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=True)
|
||||
... upload_folder(
|
||||
... repo_id=repo_id,
|
||||
... folder_path=config.output_dir,
|
||||
... commit_message=f"Epoch {epoch}",
|
||||
... ignore_patterns=["step_*", "epoch_*"],
|
||||
... )
|
||||
... else:
|
||||
... pipeline.save_pretrained(config.output_dir)
|
||||
```
|
||||
|
||||
@@ -0,0 +1,165 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
# Inference with PEFT
|
||||
|
||||
There are many adapters trained in different styles to achieve different effects. You can even combine multiple adapters to create new and unique images. With the 🤗 [PEFT](https://huggingface.co/docs/peft/index) integration in 🤗 Diffusers, it is really easy to load and manage adapters for inference. In this guide, you'll learn how to use different adapters with [Stable Diffusion XL (SDXL)](./pipelines/stable_diffusion/stable_diffusion_xl) for inference.
|
||||
|
||||
Throughout this guide, you'll use LoRA as the main adapter technique, so we'll use the terms LoRA and adapter interchangeably. You should have some familiarity with LoRA, and if you don't, we welcome you to check out the [LoRA guide](https://huggingface.co/docs/peft/conceptual_guides/lora).
|
||||
|
||||
Let's first install all the required libraries.
|
||||
|
||||
```bash
|
||||
!pip install -q transformers accelerate
|
||||
# Will be updated once the stable releases are done.
|
||||
!pip install -q git+https://github.com/huggingface/peft.git
|
||||
!pip install -q git+https://github.com/huggingface/diffusers.git
|
||||
```
|
||||
|
||||
Now, let's load a pipeline with a SDXL checkpoint:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")
|
||||
```
|
||||
|
||||
|
||||
Next, load a LoRA checkpoint with the [`~diffusers.loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] method.
|
||||
|
||||
With the 🤗 PEFT integration, you can assign a specific `adapter_name` to the checkpoint, which let's you easily switch between different LoRA checkpoints. Let's call this adapter `"toy"`.
|
||||
|
||||
```python
|
||||
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
```
|
||||
|
||||
And then perform inference:
|
||||
|
||||
```python
|
||||
prompt = "toy_face of a hacker with a hoodie"
|
||||
|
||||
lora_scale= 0.9
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=30, cross_attention_kwargs={"scale": lora_scale}, generator=torch.manual_seed(0)
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
With the `adapter_name` parameter, it is really easy to use another adapter for inference! Load the [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) adapter that has been fine-tuned to generate pixel art images, and let's call it `"pixel"`.
|
||||
|
||||
The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter. But you can activate the `"pixel"` adapter with the [`~diffusers.loaders.set_adapters`] method as shown below:
|
||||
|
||||
```python
|
||||
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipe.set_adapters("pixel")
|
||||
```
|
||||
|
||||
Let's now generate an image with the second adapter and check the result:
|
||||
|
||||
```python
|
||||
prompt = "a hacker with a hoodie, pixel art"
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=30, cross_attention_kwargs={"scale": lora_scale}, generator=torch.manual_seed(0)
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||

|
||||
|
||||
## Combine multiple adapters
|
||||
|
||||
You can also perform multi-adapter inference where you combine different adapter checkpoints for inference.
|
||||
|
||||
Once again, use the [`~diffusers.loaders.set_adapters`] method to activate two LoRA checkpoints and specify the weight for how the checkpoints should be combined.
|
||||
|
||||
```python
|
||||
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
|
||||
```
|
||||
|
||||
Now that we have set these two adapters, let's generate an image from the combined adapters!
|
||||
|
||||
<Tip>
|
||||
|
||||
LoRA checkpoints in the diffusion community are almost always obtained with [DreamBooth](https://huggingface.co/docs/diffusers/main/en/training/dreambooth). DreamBooth training often relies on "trigger" words in the input text prompts in order for the generation results to look as expected. When you combine multiple LoRA checkpoints, it's important to ensure the trigger words for the corresponding LoRA checkpoints are present in the input text prompts.
|
||||
|
||||
</Tip>
|
||||
|
||||
The trigger words for [CiroN2022/toy-face](https://hf.co/CiroN2022/toy-face) and [nerijs/pixel-art-xl](https://hf.co/nerijs/pixel-art-xl) are found in their repositories.
|
||||
|
||||
|
||||
```python
|
||||
# Notice how the prompt is constructed.
|
||||
prompt = "toy_face of a hacker with a hoodie, pixel art"
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=30, cross_attention_kwargs={"scale": 1.0}, generator=torch.manual_seed(0)
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||

|
||||
|
||||
Impressive! As you can see, the model was able to generate an image that mixes the characteristics of both adapters.
|
||||
|
||||
If you want to go back to using only one adapter, use the [`~diffusers.loaders.set_adapters`] method to activate the `"toy"` adapter:
|
||||
|
||||
```python
|
||||
# First, set the adapter.
|
||||
pipe.set_adapters("toy")
|
||||
|
||||
# Then, run inference.
|
||||
prompt = "toy_face of a hacker with a hoodie"
|
||||
lora_scale= 0.9
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=30, cross_attention_kwargs={"scale": lora_scale}, generator=torch.manual_seed(0)
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
If you want to switch to only the base model, disable all LoRAs with the [`~diffusers.loaders.disable_lora`] method.
|
||||
|
||||
|
||||
```python
|
||||
pipe.disable_lora()
|
||||
|
||||
prompt = "toy_face of a hacker with a hoodie"
|
||||
lora_scale= 0.9
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
image
|
||||
```
|
||||
|
||||

|
||||
|
||||
## Monitoring active adapters
|
||||
|
||||
You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, you can easily check the list of active adapters using the [`~diffusers.loaders.get_active_adapters`] method:
|
||||
|
||||
```python
|
||||
active_adapters = pipe.get_active_adapters()
|
||||
>>> ["toy", "pixel"]
|
||||
```
|
||||
|
||||
You can also get the active adapters of each pipeline component with [`~diffusers.loaders.get_list_adapters`]:
|
||||
|
||||
```python
|
||||
list_adapters_component_wise = pipe.get_list_adapters()
|
||||
>>> {"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]}
|
||||
```
|
||||
@@ -434,7 +434,7 @@ high_threshold = 200
|
||||
|
||||
canny_image = cv2.Canny(canny_image, low_threshold, high_threshold)
|
||||
|
||||
# zero out middle columns of image where pose will be overlayed
|
||||
# zero out middle columns of image where pose will be overlaid
|
||||
zero_start = canny_image.shape[1] // 4
|
||||
zero_end = zero_start + canny_image.shape[1] // 2
|
||||
canny_image[:, zero_start:zero_end] = 0
|
||||
|
||||
@@ -68,7 +68,7 @@ The most popular image-to-image models are [Stable Diffusion v1.5](https://huggi
|
||||
|
||||
### Stable Diffusion v1.5
|
||||
|
||||
Stable Diffusion v1.5 is a latent diffusion model intialized from an earlier checkpoint, and further finetuned for 595K steps on 512x512 images. To use this pipeline for image-to-image, you'll need to prepare an initial image to pass to the pipeline. Then you can pass a prompt and the image to the pipeline to generate a new image:
|
||||
Stable Diffusion v1.5 is a latent diffusion model initialized from an earlier checkpoint, and further finetuned for 595K steps on 512x512 images. To use this pipeline for image-to-image, you'll need to prepare an initial image to pass to the pipeline. Then you can pass a prompt and the image to the pipeline to generate a new image:
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
@@ -574,7 +574,7 @@ image
|
||||
|
||||
## Optimize
|
||||
|
||||
It can be difficult and slow to run diffusion models if you're resource constrained, but it dosen't have to be with a few optimization tricks. One of the biggest (and easiest) optimizations you can enable is switching to memory-efficient attention. If you're using PyTorch 2.0, [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention) is automatically enabled and you don't need to do anything else. For non-PyTorch 2.0 users, you can install and use [xFormers](../optimization/xformers)'s implementation of memory-efficient attention. Both options reduce memory usage and accelerate inference.
|
||||
It can be difficult and slow to run diffusion models if you're resource constrained, but it doesn't have to be with a few optimization tricks. One of the biggest (and easiest) optimizations you can enable is switching to memory-efficient attention. If you're using PyTorch 2.0, [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention) is automatically enabled and you don't need to do anything else. For non-PyTorch 2.0 users, you can install and use [xFormers](../optimization/xformers)'s implementation of memory-efficient attention. Both options reduce memory usage and accelerate inference.
|
||||
|
||||
You can also offload the model to the GPU to save even more memory:
|
||||
|
||||
|
||||
@@ -153,7 +153,7 @@ exactly the same hardware and PyTorch version for full reproducibility.
|
||||
|
||||
You can also configure PyTorch to use deterministic algorithms to create a reproducible pipeline. However, you should be aware that deterministic algorithms may be slower than nondeterministic ones and you may observe a decrease in performance. But if reproducibility is important to you, then this is the way to go!
|
||||
|
||||
Nondeterministic behavior occurs when operations are launched in more than one CUDA stream. To avoid this, set the environment varibale [`CUBLAS_WORKSPACE_CONFIG`](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility) to `:16:8` to only use one buffer size during runtime.
|
||||
Nondeterministic behavior occurs when operations are launched in more than one CUDA stream. To avoid this, set the environment variable [`CUBLAS_WORKSPACE_CONFIG`](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility) to `:16:8` to only use one buffer size during runtime.
|
||||
|
||||
PyTorch typically benchmarks multiple algorithms to select the fastest one, but if you want reproducibility, you should disable this feature because the benchmark may select different algorithms each time. Lastly, pass `True` to [`torch.use_deterministic_algorithms`](https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html) to enable deterministic algorithms.
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ export_to_gif(images[1], "cake_3d.gif")
|
||||
|
||||
## Image-to-3D
|
||||
|
||||
To generate a 3D object from another image, use the [`ShapEImg2ImgPipeline`]. You can use an existing image or generate an entirely new one. Let's use the the [Kandinsky 2.1](../api/pipelines/kandinsky) model to generate a new image.
|
||||
To generate a 3D object from another image, use the [`ShapEImg2ImgPipeline`]. You can use an existing image or generate an entirely new one. Let's use the [Kandinsky 2.1](../api/pipelines/kandinsky) model to generate a new image.
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
@@ -169,7 +169,7 @@ Feel free to choose any prompt you like if you want to generate something else!
|
||||
>>> width = 512 # default width of Stable Diffusion
|
||||
>>> num_inference_steps = 25 # Number of denoising steps
|
||||
>>> guidance_scale = 7.5 # Scale for classifier-free guidance
|
||||
>>> generator = torch.manual_seed(0) # Seed generator to create the inital latent noise
|
||||
>>> generator = torch.manual_seed(0) # Seed generator to create the initial latent noise
|
||||
>>> batch_size = len(prompt)
|
||||
```
|
||||
|
||||
|
||||
@@ -283,36 +283,27 @@ TensorBoard에 로깅, 그래디언트 누적 및 혼합 정밀도 학습을 쉽
|
||||
|
||||
```py
|
||||
>>> from accelerate import Accelerator
|
||||
>>> from huggingface_hub import HfFolder, Repository, whoami
|
||||
>>> from huggingface_hub import create_repo, upload_folder
|
||||
>>> from tqdm.auto import tqdm
|
||||
>>> from pathlib import Path
|
||||
>>> import os
|
||||
|
||||
|
||||
>>> def get_full_repo_name(model_id: str, organization: str = None, token: str = None):
|
||||
... if token is None:
|
||||
... token = HfFolder.get_token()
|
||||
... if organization is None:
|
||||
... username = whoami(token)["name"]
|
||||
... return f"{username}/{model_id}"
|
||||
... else:
|
||||
... return f"{organization}/{model_id}"
|
||||
|
||||
|
||||
>>> def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
|
||||
... # accelerator와 tensorboard 로깅 초기화
|
||||
... # Initialize accelerator and tensorboard logging
|
||||
... accelerator = Accelerator(
|
||||
... mixed_precision=config.mixed_precision,
|
||||
... gradient_accumulation_steps=config.gradient_accumulation_steps,
|
||||
... log_with="tensorboard",
|
||||
... logging_dir=os.path.join(config.output_dir, "logs"),
|
||||
... project_dir=os.path.join(config.output_dir, "logs"),
|
||||
... )
|
||||
... if accelerator.is_main_process:
|
||||
... if config.push_to_hub:
|
||||
... repo_name = get_full_repo_name(Path(config.output_dir).name)
|
||||
... repo = Repository(config.output_dir, clone_from=repo_name)
|
||||
... elif config.output_dir is not None:
|
||||
... if config.output_dir is not None:
|
||||
... os.makedirs(config.output_dir, exist_ok=True)
|
||||
... if config.push_to_hub:
|
||||
... repo_id = create_repo(
|
||||
... repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True
|
||||
... ).repo_id
|
||||
... accelerator.init_trackers("train_example")
|
||||
|
||||
... # 모든 것이 준비되었습니다.
|
||||
@@ -369,7 +360,12 @@ TensorBoard에 로깅, 그래디언트 누적 및 혼합 정밀도 학습을 쉽
|
||||
|
||||
... if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
|
||||
... if config.push_to_hub:
|
||||
... repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=True)
|
||||
... upload_folder(
|
||||
... repo_id=repo_id,
|
||||
... folder_path=config.output_dir,
|
||||
... commit_message=f"Epoch {epoch}",
|
||||
... ignore_patterns=["step_*", "epoch_*"],
|
||||
... )
|
||||
... else:
|
||||
... pipeline.save_pretrained(config.output_dir)
|
||||
```
|
||||
|
||||
@@ -41,9 +41,10 @@ If a community doesn't work as expected, please open an issue and ping the autho
|
||||
| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon)
|
||||
| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#Zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit) |
|
||||
Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline support unlimited length of prompt and negative prompt, use A1111 style of prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | - | [Andrew Zhu](https://xhinker.medium.com/) |
|
||||
FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
|
||||
FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
|
||||
sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
|
||||
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
|
||||
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
@@ -765,7 +766,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
|
||||
#There are multiple possible scenarios:
|
||||
#The pipeline with the merged checkpoints is returned in all the scenarios
|
||||
|
||||
#Compatible checkpoints a.k.a matched model_index.json files. Ignores the meta attributes in model_index.json during comparision.( attrs with _ as prefix )
|
||||
#Compatible checkpoints a.k.a matched model_index.json files. Ignores the meta attributes in model_index.json during comparison.( attrs with _ as prefix )
|
||||
merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4","CompVis/stable-diffusion-v1-2"], interp = "sigmoid", alpha = 0.4)
|
||||
|
||||
#Incompatible checkpoints in model_index.json but merge might be possible. Use force = True to ignore model_index.json compatibility
|
||||
@@ -1529,14 +1530,14 @@ print("Latency of StableDiffusionPipeline--fp32",latency)
|
||||
|
||||

|
||||
|
||||
CLIP guided stable diffusion images mixing pipline allows to combine two images using standard diffusion models.
|
||||
CLIP guided stable diffusion images mixing pipeline allows to combine two images using standard diffusion models.
|
||||
This approach is using (optional) CoCa model to avoid writing image description.
|
||||
[More code examples](https://github.com/TheDenk/images_mixing)
|
||||
|
||||
|
||||
### Stable Diffusion XL Long Weighted Prompt Pipeline
|
||||
|
||||
This SDXL pipeline support unlimted length prompt and negative prompt, compatible with A1111 prompt weighted style.
|
||||
This SDXL pipeline support unlimited length prompt and negative prompt, compatible with A1111 prompt weighted style.
|
||||
|
||||
You can provide both `prompt` and `prompt_2`. if only one prompt is provided, `prompt_2` will be a copy of the provided `prompt`. Here is a sample code to use this pipeline.
|
||||
|
||||
@@ -1605,7 +1606,7 @@ coca_transform = open_clip.image_transform(
|
||||
)
|
||||
coca_tokenizer = SimpleTokenizer()
|
||||
|
||||
# Pipline creating
|
||||
# Pipeline creating
|
||||
mixing_pipeline = DiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
custom_pipeline="clip_guided_images_mixing_stable_diffusion",
|
||||
@@ -1619,7 +1620,7 @@ mixing_pipeline = DiffusionPipeline.from_pretrained(
|
||||
mixing_pipeline.enable_attention_slicing()
|
||||
mixing_pipeline = mixing_pipeline.to("cuda")
|
||||
|
||||
# Pipline running
|
||||
# Pipeline running
|
||||
generator = torch.Generator(device="cuda").manual_seed(17)
|
||||
|
||||
def download_image(url):
|
||||
@@ -2147,3 +2148,40 @@ edit_kcross_attention_kwargswargs = {
|
||||
```
|
||||
|
||||
Side note: See [this GitHub gist](https://gist.github.com/UmerHA/b65bb5fb9626c9c73f3ade2869e36164) if you want to visualize the attention maps.
|
||||
|
||||
### Latent Consistency Pipeline
|
||||
|
||||
Latent Consistency Models was proposed in [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) by *Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, Hang Zhao* from Tsinghua University.
|
||||
|
||||
The abstract of the paper reads as follows:
|
||||
|
||||
*Latent Diffusion models (LDMs) have achieved remarkable results in synthesizing high-resolution images. However, the iterative sampling process is computationally intensive and leads to slow generation. Inspired by Consistency Models (song et al.), we propose Latent Consistency Models (LCMs), enabling swift inference with minimal steps on any pre-trained LDMs, including Stable Diffusion (rombach et al). Viewing the guided reverse diffusion process as solving an augmented probability flow ODE (PF-ODE), LCMs are designed to directly predict the solution of such ODE in latent space, mitigating the need for numerous iterations and allowing rapid, high-fidelity sampling. Efficiently distilled from pre-trained classifier-free guided diffusion models, a high-quality 768 x 768 2~4-step LCM takes only 32 A100 GPU hours for training. Furthermore, we introduce Latent Consistency Fine-tuning (LCF), a novel method that is tailored for fine-tuning LCMs on customized image datasets. Evaluation on the LAION-5B-Aesthetics dataset demonstrates that LCMs achieve state-of-the-art text-to-image generation performance with few-step inference. Project Page: [this https URL](https://latent-consistency-models.github.io/)*
|
||||
|
||||
The model can be used with `diffusers` as follows:
|
||||
|
||||
- *1. Load the model from the community pipeline.*
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", custom_pipeline="latent_consistency_txt2img")
|
||||
|
||||
# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
|
||||
pipe.to(torch_device="cuda", torch_dtype=torch.float32)
|
||||
```
|
||||
|
||||
- 2. Run inference with as little as 4 steps:
|
||||
|
||||
```py
|
||||
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
|
||||
|
||||
# Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
|
||||
num_inference_steps = 4
|
||||
|
||||
images = pipe(prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=8.0, lcm_origin_steps=50, output_type="pil").images
|
||||
```
|
||||
|
||||
For any questions or feedback, feel free to reach out to [Simian Luo](https://github.com/luosiallen).
|
||||
|
||||
You can also try this pipeline directly in the [🚀 official spaces](https://huggingface.co/spaces/SimianLuo/Latent_Consistency_Model).
|
||||
|
||||
+730
@@ -0,0 +1,730 @@
|
||||
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
||||
# and https://github.com/hojonathanho/diffusion
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, ConfigMixin, DiffusionPipeline, SchedulerMixin, UNet2DConditionModel, logging
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class LatentConsistencyModelPipeline(DiffusionPipeline):
|
||||
_optional_components = ["scheduler"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: "LCMScheduler",
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
scheduler = (
|
||||
scheduler
|
||||
if scheduler is not None
|
||||
else LCMScheduler(
|
||||
beta_start=0.00085, beta_end=0.0120, beta_schedule="scaled_linear", prediction_type="epsilon"
|
||||
)
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
prompt_embeds: None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
"""
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
pass
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
len(prompt)
|
||||
else:
|
||||
prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if self.text_encoder is not None:
|
||||
prompt_embeds_dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
prompt_embeds_dtype = self.unet.dtype
|
||||
else:
|
||||
prompt_embeds_dtype = prompt_embeds.dtype
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# Don't need to get uncond prompt embedding because of LCM Guided Distillation
|
||||
return prompt_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
else:
|
||||
if torch.is_tensor(image):
|
||||
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
||||
else:
|
||||
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
||||
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
return image, has_nsfw_concept
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if latents is None:
|
||||
latents = torch.randn(shape, dtype=dtype).to(device)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
"""
|
||||
see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
||||
Args:
|
||||
timesteps: torch.Tensor: generate embedding vectors at these timesteps
|
||||
embedding_dim: int: dimension of the embeddings to generate
|
||||
dtype: data type of the generated embeddings
|
||||
Returns:
|
||||
embedding vectors with shape `(len(timesteps), embedding_dim)`
|
||||
"""
|
||||
assert len(w.shape) == 1
|
||||
w = w * 1000.0
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
||||
emb = w.to(dtype)[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1))
|
||||
assert emb.shape == (w.shape[0], embedding_dim)
|
||||
return emb
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = 768,
|
||||
width: Optional[int] = 768,
|
||||
guidance_scale: float = 7.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
num_inference_steps: int = 4,
|
||||
lcm_origin_steps: int = 50,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# do_classifier_free_guidance = guidance_scale > 0.0 # In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, lcm_origin_steps)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variable
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
latents,
|
||||
)
|
||||
bs = batch_size * num_images_per_prompt
|
||||
|
||||
# 6. Get Guidance Scale Embedding
|
||||
w = torch.tensor(guidance_scale).repeat(bs)
|
||||
w_embedding = self.get_w_embedding(w, embedding_dim=256).to(device=device, dtype=latents.dtype)
|
||||
|
||||
# 7. LCM MultiStep Sampling Loop:
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
ts = torch.full((bs,), t, device=device, dtype=torch.long)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
|
||||
# model prediction (v-prediction, eps, x)
|
||||
model_pred = self.unet(
|
||||
latents,
|
||||
ts,
|
||||
timestep_cond=w_embedding,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents, denoised = self.scheduler.step(model_pred, i, t, latents, return_dict=False)
|
||||
|
||||
# # call the callback, if provided
|
||||
# if i == len(timesteps) - 1:
|
||||
progress_bar.update()
|
||||
|
||||
denoised = denoised.to(prompt_embeds.dtype)
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = denoised
|
||||
has_nsfw_concept = None
|
||||
|
||||
if has_nsfw_concept is None:
|
||||
do_denormalize = [True] * image.shape[0]
|
||||
else:
|
||||
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
|
||||
|
||||
@dataclass
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
|
||||
class LCMSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's `step` function output.
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
||||
`pred_original_sample` can be used to preview progress or for guidance.
|
||||
"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
denoised: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
||||
to that part of the diffusion process.
|
||||
Args:
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr(betas):
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
|
||||
Args:
|
||||
betas (`torch.FloatTensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
Returns:
|
||||
`torch.FloatTensor`: rescaled betas with zero terminal SNR
|
||||
"""
|
||||
# Convert betas to alphas_bar_sqrt
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
||||
alphas = torch.cat([alphas_bar[0:1], alphas])
|
||||
betas = 1 - alphas
|
||||
|
||||
return betas
|
||||
|
||||
|
||||
class LCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
`LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
|
||||
non-Markovian guidance.
|
||||
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
||||
methods the library implements for all schedulers such as loading and saving.
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
beta_start (`float`, defaults to 0.0001):
|
||||
The starting `beta` value of inference.
|
||||
beta_end (`float`, defaults to 0.02):
|
||||
The final `beta` value.
|
||||
beta_schedule (`str`, defaults to `"linear"`):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
||||
trained_betas (`np.ndarray`, *optional*):
|
||||
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
||||
clip_sample (`bool`, defaults to `True`):
|
||||
Clip the predicted sample for numerical stability.
|
||||
clip_sample_range (`float`, defaults to 1.0):
|
||||
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||
set_alpha_to_one (`bool`, defaults to `True`):
|
||||
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
|
||||
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
||||
otherwise it uses the alpha value at step 0.
|
||||
steps_offset (`int`, defaults to 0):
|
||||
An offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
||||
Diffusion.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
||||
Video](https://imagen.research.google/video/paper.pdf) paper).
|
||||
thresholding (`bool`, defaults to `False`):
|
||||
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
||||
as Stable Diffusion.
|
||||
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
||||
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
||||
sample_max_value (`float`, defaults to 1.0):
|
||||
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||
timestep_spacing (`str`, defaults to `"leading"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
||||
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
||||
"""
|
||||
|
||||
# _compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
clip_sample: bool = True,
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: str = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
clip_sample_range: float = 1.0,
|
||||
sample_max_value: float = 1.0,
|
||||
timestep_spacing: str = "leading",
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = (
|
||||
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
# Rescale for zero SNR
|
||||
if rescale_betas_zero_snr:
|
||||
self.betas = rescale_zero_terminal_snr(self.betas)
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
|
||||
# At every step in ddim, we are looking into the previous alphas_cumprod
|
||||
# For the final step, there is no previous alphas_cumprod because we are already at 0
|
||||
# `set_alpha_to_one` decides whether we set this parameter simply to one or
|
||||
# whether we use the final alpha of the "non-previous" one.
|
||||
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
||||
|
||||
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
||||
"""
|
||||
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
||||
current timestep.
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The input sample.
|
||||
timestep (`int`, *optional*):
|
||||
The current timestep in the diffusion chain.
|
||||
Returns:
|
||||
`torch.FloatTensor`:
|
||||
A scaled input sample.
|
||||
"""
|
||||
return sample
|
||||
|
||||
def _get_variance(self, timestep, prev_timestep):
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
||||
|
||||
return variance
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
||||
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
||||
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
||||
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
||||
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
||||
https://arxiv.org/abs/2205.11487
|
||||
"""
|
||||
dtype = sample.dtype
|
||||
batch_size, channels, height, width = sample.shape
|
||||
|
||||
if dtype not in (torch.float32, torch.float64):
|
||||
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
||||
|
||||
# Flatten sample for doing quantile calculation along each image
|
||||
sample = sample.reshape(batch_size, channels * height * width)
|
||||
|
||||
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
||||
|
||||
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
||||
s = torch.clamp(
|
||||
s, min=1, max=self.config.sample_max_value
|
||||
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
||||
|
||||
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
||||
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
||||
|
||||
sample = sample.reshape(batch_size, channels, height, width)
|
||||
sample = sample.to(dtype)
|
||||
|
||||
return sample
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
"""
|
||||
|
||||
if num_inference_steps > self.config.num_train_timesteps:
|
||||
raise ValueError(
|
||||
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
||||
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
||||
f" maximal {self.config.num_train_timesteps} timesteps."
|
||||
)
|
||||
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
# LCM Timesteps Setting: # Linear Spacing
|
||||
c = self.config.num_train_timesteps // lcm_origin_steps
|
||||
lcm_origin_timesteps = np.asarray(list(range(1, lcm_origin_steps + 1))) * c - 1 # LCM Training Steps Schedule
|
||||
skipping_step = len(lcm_origin_timesteps) // num_inference_steps
|
||||
timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] # LCM Inference Steps Schedule
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
|
||||
|
||||
def get_scalings_for_boundary_condition_discrete(self, t):
|
||||
self.sigma_data = 0.5 # Default: 0.5
|
||||
|
||||
# By dividing 0.1: This is almost a delta function at t=0.
|
||||
c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
|
||||
c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
|
||||
return c_skip, c_out
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timeindex: int,
|
||||
timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
eta: float = 0.0,
|
||||
use_clipped_model_output: bool = False,
|
||||
generator=None,
|
||||
variance_noise: Optional[torch.FloatTensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[LCMSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`float`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
eta (`float`):
|
||||
The weight of noise for added noise in diffusion step.
|
||||
use_clipped_model_output (`bool`, defaults to `False`):
|
||||
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
|
||||
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
|
||||
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
|
||||
`use_clipped_model_output` has no effect.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
variance_noise (`torch.FloatTensor`):
|
||||
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
||||
itself. Useful for methods such as [`CycleDiffusion`].
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
# 1. get previous step value
|
||||
prev_timeindex = timeindex + 1
|
||||
if prev_timeindex < len(self.timesteps):
|
||||
prev_timestep = self.timesteps[prev_timeindex]
|
||||
else:
|
||||
prev_timestep = timestep
|
||||
|
||||
# 2. compute alphas, betas
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
# 3. Get scalings for boundary conditions
|
||||
c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
|
||||
|
||||
# 4. Different Parameterization:
|
||||
parameterization = self.config.prediction_type
|
||||
|
||||
if parameterization == "epsilon": # noise-prediction
|
||||
pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
|
||||
|
||||
elif parameterization == "sample": # x-prediction
|
||||
pred_x0 = model_output
|
||||
|
||||
elif parameterization == "v_prediction": # v-prediction
|
||||
pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
|
||||
|
||||
# 4. Denoise model output using boundary conditions
|
||||
denoised = c_out * pred_x0 + c_skip * sample
|
||||
|
||||
# 5. Sample z ~ N(0, I), For MultiStep Inference
|
||||
# Noise is not used for one-step sampling.
|
||||
if len(self.timesteps) > 1:
|
||||
noise = torch.randn(model_output.shape).to(model_output.device)
|
||||
prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
|
||||
else:
|
||||
prev_sample = denoised
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample, denoised)
|
||||
|
||||
return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
|
||||
def get_velocity(
|
||||
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as sample
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
|
||||
timesteps = timesteps.to(sample.device)
|
||||
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(sample.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||
return velocity
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
@@ -553,7 +553,7 @@ class OnnxStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
|
||||
instead.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
||||
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
||||
The initial image will be used as the starting point for the image generation process. Can also accpet
|
||||
The initial image will be used as the starting point for the image generation process. Can also accept
|
||||
image latents as `image`, if passing latents directly, it will not be encoded again.
|
||||
control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
||||
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
||||
|
||||
@@ -657,7 +657,7 @@ class TensorRTStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
|
||||
instead.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
||||
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
||||
The initial image will be used as the starting point for the image generation process. Can also accpet
|
||||
The initial image will be used as the starting point for the image generation process. Can also accept
|
||||
image latents as `image`, if passing latents directly, it will not be encoded again.
|
||||
control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
||||
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
||||
|
||||
@@ -48,7 +48,7 @@ write_basic_config()
|
||||
|
||||
Now let's get our dataset. Download dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it.
|
||||
|
||||
We also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the the given target image. The following flags enable the regularization `with_prior_preservation`, `real_prior` with `prior_loss_weight=1.`.
|
||||
We also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the given target image. The following flags enable the regularization `with_prior_preservation`, `real_prior` with `prior_loss_weight=1.`.
|
||||
The `class_prompt` should be the category name same as target image. The collected real images are with text captions similar to the `class_prompt`. The retrieved image are saved in `class_data_dir`. You can disable `real_prior` to use generated images as regularization. To collect the real images use this command first before training.
|
||||
|
||||
```bash
|
||||
@@ -82,7 +82,7 @@ accelerate launch train_custom_diffusion.py \
|
||||
|
||||
**Use `--enable_xformers_memory_efficient_attention` for faster training with lower VRAM requirement (16GB per GPU). Follow [this guide](https://github.com/facebookresearch/xformers) for installation instructions.**
|
||||
|
||||
To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (whcih we HIGHLY recommend), follow these steps:
|
||||
To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (which we HIGHLY recommend), follow these steps:
|
||||
|
||||
* Install `wandb`: `pip install wandb`.
|
||||
* Authorize: `wandb login`.
|
||||
|
||||
@@ -1119,7 +1119,7 @@ def main(args):
|
||||
unet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
|
||||
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
|
||||
@@ -4,7 +4,6 @@ import logging
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@@ -16,7 +15,7 @@ import transformers
|
||||
from flax import jax_utils
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import shard
|
||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from jax.experimental.compilation_cache import compilation_cache as cc
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
@@ -318,16 +317,6 @@ class PromptDataset(Dataset):
|
||||
return example
|
||||
|
||||
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||
if token is None:
|
||||
token = HfFolder.get_token()
|
||||
if organization is None:
|
||||
username = whoami(token)["name"]
|
||||
return f"{username}/{model_id}"
|
||||
else:
|
||||
return f"{organization}/{model_id}"
|
||||
|
||||
|
||||
def get_params_to_save(params):
|
||||
return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
|
||||
|
||||
@@ -392,22 +381,14 @@ def main():
|
||||
|
||||
# Handle the repository creation
|
||||
if jax.process_index() == 0:
|
||||
if args.push_to_hub:
|
||||
if args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
create_repo(repo_name, exist_ok=True, token=args.hub_token)
|
||||
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
|
||||
|
||||
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
||||
if "step_*" not in gitignore:
|
||||
gitignore.write("step_*\n")
|
||||
if "epoch_*" not in gitignore:
|
||||
gitignore.write("epoch_*\n")
|
||||
elif args.output_dir is not None:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
||||
).repo_id
|
||||
|
||||
# Load the tokenizer and add the placeholder token as a additional special token
|
||||
if args.tokenizer_name:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
|
||||
@@ -668,7 +649,12 @@ def main():
|
||||
|
||||
if args.push_to_hub:
|
||||
message = f"checkpoint-{step}" if step is not None else "End of training"
|
||||
repo.push_to_hub(commit_message=message, blocking=False, auto_lfs_prune=True)
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
commit_message=message,
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
global_step = 0
|
||||
|
||||
|
||||
@@ -794,7 +794,7 @@ def main(args):
|
||||
text_encoder.requires_grad_(False)
|
||||
unet.requires_grad_(False)
|
||||
|
||||
# For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
|
||||
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
|
||||
@@ -707,7 +707,7 @@ def main(args):
|
||||
text_encoder_two.requires_grad_(False)
|
||||
unet.requires_grad_(False)
|
||||
|
||||
# For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
|
||||
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
@@ -755,26 +755,17 @@ def main(args):
|
||||
# Set the `lora_layer` attribute of the attention-related matrices.
|
||||
attn_module.to_q.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_q.in_features,
|
||||
out_features=attn_module.to_q.out_features,
|
||||
rank=args.rank,
|
||||
dtype=torch.float32,
|
||||
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_k.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_k.in_features,
|
||||
out_features=attn_module.to_k.out_features,
|
||||
rank=args.rank,
|
||||
dtype=torch.float32,
|
||||
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_v.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_v.in_features,
|
||||
out_features=attn_module.to_v.out_features,
|
||||
rank=args.rank,
|
||||
dtype=torch.float32,
|
||||
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_out[0].set_lora_layer(
|
||||
@@ -782,7 +773,6 @@ def main(args):
|
||||
in_features=attn_module.to_out[0].in_features,
|
||||
out_features=attn_module.to_out[0].out_features,
|
||||
rank=args.rank,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ The `text` include the tag `Teyvat`, `Name`,`Element`, `Weapon`, `Region`, `Mode
|
||||
|
||||
## Training
|
||||
|
||||
The arguement `placement` can be `cpu`, `auto`, `cuda`, with `cpu` the GPU RAM required can be minimized to 4GB but will deceleration, with `cuda` you can also reduce GPU memory by half but accelerated training, with `auto` a more balanced solution for speed and memory can be obtained。
|
||||
The argument `placement` can be `cpu`, `auto`, `cuda`, with `cpu` the GPU RAM required can be minimized to 4GB but will deceleration, with `cuda` you can also reduce GPU memory by half but accelerated training, with `auto` a more balanced solution for speed and memory can be obtained。
|
||||
|
||||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import math
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
@@ -13,7 +13,7 @@ import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from neural_compressor.utils import logger
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
@@ -413,16 +413,6 @@ class TextualInversionDataset(Dataset):
|
||||
return example
|
||||
|
||||
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||
if token is None:
|
||||
token = HfFolder.get_token()
|
||||
if organization is None:
|
||||
username = whoami(token)["name"]
|
||||
return f"{username}/{model_id}"
|
||||
else:
|
||||
return f"{organization}/{model_id}"
|
||||
|
||||
|
||||
def freeze_params(params):
|
||||
for param in params:
|
||||
param.requires_grad = False
|
||||
@@ -461,21 +451,14 @@ def main():
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub:
|
||||
if args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
repo = Repository(args.output_dir, clone_from=repo_name)
|
||||
|
||||
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
||||
if "step_*" not in gitignore:
|
||||
gitignore.write("step_*\n")
|
||||
if "epoch_*" not in gitignore:
|
||||
gitignore.write("epoch_*\n")
|
||||
elif args.output_dir is not None:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
||||
).repo_id
|
||||
|
||||
# Load the tokenizer and add the placeholder token as a additional special token
|
||||
if args.tokenizer_name:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
|
||||
@@ -982,7 +965,12 @@ def main():
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
commit_message="End of training",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
@@ -323,7 +323,7 @@ accelerate launch train_dreambooth.py \
|
||||
|
||||
### Using DreamBooth for other pipelines than Stable Diffusion
|
||||
|
||||
Altdiffusion also support dreambooth now, the runing comman is basically the same as abouve, all you need to do is replace the `MODEL_NAME` like this:
|
||||
Altdiffusion also support dreambooth now, the runing comman is basically the same as above, all you need to do is replace the `MODEL_NAME` like this:
|
||||
One can now simply change the `pretrained_model_name_or_path` to another architecture such as [`AltDiffusion`](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion).
|
||||
|
||||
```
|
||||
|
||||
+13
-27
@@ -4,7 +4,6 @@ import logging
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import accelerate
|
||||
import datasets
|
||||
@@ -14,7 +13,7 @@ from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
|
||||
from onnxruntime.training.ortmodule import ORTModule
|
||||
from packaging import version
|
||||
@@ -277,16 +276,6 @@ def parse_args():
|
||||
return args
|
||||
|
||||
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||
if token is None:
|
||||
token = HfFolder.get_token()
|
||||
if organization is None:
|
||||
username = whoami(token)["name"]
|
||||
return f"{username}/{model_id}"
|
||||
else:
|
||||
return f"{organization}/{model_id}"
|
||||
|
||||
|
||||
def main(args):
|
||||
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
||||
accelerator_project_config = ProjectConfiguration(
|
||||
@@ -360,22 +349,14 @@ def main(args):
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub:
|
||||
if args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
create_repo(repo_name, exist_ok=True, token=args.hub_token)
|
||||
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
|
||||
|
||||
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
||||
if "step_*" not in gitignore:
|
||||
gitignore.write("step_*\n")
|
||||
if "epoch_*" not in gitignore:
|
||||
gitignore.write("epoch_*\n")
|
||||
elif args.output_dir is not None:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
||||
).repo_id
|
||||
|
||||
# Initialize the model
|
||||
if args.model_config_name_or_path is None:
|
||||
model = UNet2DModel(
|
||||
@@ -691,7 +672,12 @@ def main(args):
|
||||
ema_model.restore(unet.parameters())
|
||||
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
commit_message=f"Epoch {epoch}",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
@@ -151,7 +151,7 @@ telling JAX which input arguments are static, that is, arguments that
|
||||
are known at compile time and won't change. In our case, it is num_inference_steps,
|
||||
height, width and return_latents.
|
||||
|
||||
Once the function is compiled, these parameters are ommited from future calls and
|
||||
Once the function is compiled, these parameters are omitted from future calls and
|
||||
cannot be changed without modifying the code and recompiling.
|
||||
|
||||
```python
|
||||
|
||||
@@ -839,7 +839,7 @@ def main(args):
|
||||
all_images = []
|
||||
crop_top_lefts = []
|
||||
for image in images:
|
||||
original_sizes.append((image.height, image.width))
|
||||
original_sizes.append((image.width, image.height))
|
||||
image = train_resize(image)
|
||||
if args.center_crop:
|
||||
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
|
||||
|
||||
@@ -825,7 +825,7 @@ def main(args):
|
||||
all_images = []
|
||||
crop_top_lefts = []
|
||||
for image in images:
|
||||
original_sizes.append((image.height, image.width))
|
||||
original_sizes.append((image.width, image.height))
|
||||
image = train_resize(image)
|
||||
if args.center_crop:
|
||||
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
|
||||
@@ -1038,7 +1038,6 @@ def main(args):
|
||||
prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
|
||||
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
|
||||
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
|
||||
prompt_embeds = prompt_embeds
|
||||
model_pred = unet(
|
||||
noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
|
||||
).sample
|
||||
|
||||
@@ -6,7 +6,6 @@ import os
|
||||
import shutil
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import accelerate
|
||||
import datasets
|
||||
@@ -16,7 +15,7 @@ from accelerate import Accelerator, InitProcessGroupKwargs
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from packaging import version
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
@@ -273,16 +272,6 @@ def parse_args():
|
||||
return args
|
||||
|
||||
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||
if token is None:
|
||||
token = HfFolder.get_token()
|
||||
if organization is None:
|
||||
username = whoami(token)["name"]
|
||||
return f"{username}/{model_id}"
|
||||
else:
|
||||
return f"{organization}/{model_id}"
|
||||
|
||||
|
||||
def main(args):
|
||||
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
@@ -356,22 +345,14 @@ def main(args):
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub:
|
||||
if args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
create_repo(repo_name, exist_ok=True, token=args.hub_token)
|
||||
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
|
||||
|
||||
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
||||
if "step_*" not in gitignore:
|
||||
gitignore.write("step_*\n")
|
||||
if "epoch_*" not in gitignore:
|
||||
gitignore.write("epoch_*\n")
|
||||
elif args.output_dir is not None:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
||||
).repo_id
|
||||
|
||||
# Initialize the model
|
||||
if args.model_config_name_or_path is None:
|
||||
model = UNet2DModel(
|
||||
@@ -413,6 +394,14 @@ def main(args):
|
||||
model_config=model.config,
|
||||
)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
args.mixed_precision = accelerator.mixed_precision
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
args.mixed_precision = accelerator.mixed_precision
|
||||
|
||||
if args.enable_xformers_memory_efficient_attention:
|
||||
if is_xformers_available():
|
||||
import xformers
|
||||
@@ -559,11 +548,9 @@ def main(args):
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
clean_images = batch["input"]
|
||||
clean_images = batch["input"].to(weight_dtype)
|
||||
# Sample noise that we'll add to the images
|
||||
noise = torch.randn(
|
||||
clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
|
||||
).to(clean_images.device)
|
||||
noise = torch.randn(clean_images.shape, dtype=weight_dtype, device=clean_images.device)
|
||||
bsz = clean_images.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
@@ -579,15 +566,14 @@ def main(args):
|
||||
model_output = model(noisy_images, timesteps).sample
|
||||
|
||||
if args.prediction_type == "epsilon":
|
||||
loss = F.mse_loss(model_output, noise) # this could have different weights!
|
||||
loss = F.mse_loss(model_output.float(), noise.float()) # this could have different weights!
|
||||
elif args.prediction_type == "sample":
|
||||
alpha_t = _extract_into_tensor(
|
||||
noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
|
||||
)
|
||||
snr_weights = alpha_t / (1 - alpha_t)
|
||||
loss = snr_weights * F.mse_loss(
|
||||
model_output, clean_images, reduction="none"
|
||||
) # use SNR weighting from distillation paper
|
||||
# use SNR weighting from distillation paper
|
||||
loss = snr_weights * F.mse_loss(model_output.float(), clean_images.float(), reduction="none")
|
||||
loss = loss.mean()
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction type: {args.prediction_type}")
|
||||
@@ -703,7 +689,12 @@ def main(args):
|
||||
ema_model.restore(unet.parameters())
|
||||
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
commit_message=f"Epoch {epoch}",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
# Würstchen text-to-image fine-tuning
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date. To do this, execute the following steps in a new virtual environment:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install .
|
||||
```
|
||||
|
||||
Then cd into the example folder and run
|
||||
```bash
|
||||
cd examples/wuerstchen/text_to_image
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag to the training script. To log in, run:
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
## Prior training
|
||||
|
||||
You can fine-tune the Würstchen prior model with the `train_text_to_image_prior.py` script. Note that we currently support `--gradient_checkpointing` for prior model fine-tuning so you can use it for more GPU memory constrained setups.
|
||||
|
||||
<br>
|
||||
|
||||
<!-- accelerate_snippet_start -->
|
||||
```bash
|
||||
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
|
||||
|
||||
accelerate launch train_text_to_image_prior.py \
|
||||
--mixed_precision="fp16" \
|
||||
--dataset_name=$DATASET_NAME \
|
||||
--resolution=768 \
|
||||
--train_batch_size=4 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--dataloader_num_workers=4 \
|
||||
--max_train_steps=15000 \
|
||||
--learning_rate=1e-05 \
|
||||
--max_grad_norm=1 \
|
||||
--checkpoints_total_limit=3 \
|
||||
--lr_scheduler="constant" --lr_warmup_steps=0 \
|
||||
--validation_prompts="A robot pokemon, 4k photo" \
|
||||
--report_to="wandb" \
|
||||
--push_to_hub \
|
||||
--output_dir="wuerstchen-prior-pokemon-model"
|
||||
```
|
||||
<!-- accelerate_snippet_end -->
|
||||
|
||||
## Training with LoRA
|
||||
|
||||
Low-Rank Adaption of Large Language Models (or LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
|
||||
|
||||
In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
|
||||
|
||||
- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
|
||||
- Rank-decomposition matrices have significantly fewer parameters than original model, which means that trained LoRA weights are easily portable.
|
||||
- LoRA attention layers allow to control to which extent the model is adapted toward new training images via a `scale` parameter.
|
||||
|
||||
|
||||
### Prior Training
|
||||
|
||||
First, you need to set up your development environment as explained in the [installation](#Running-locally-with-PyTorch) section. Make sure to set the `DATASET_NAME` environment variable. Here, we will use the [Pokemon captions dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
|
||||
|
||||
```bash
|
||||
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
|
||||
|
||||
accelerate launch train_text_to_image_prior_lora.py \
|
||||
--mixed_precision="fp16" \
|
||||
--dataset_name=$DATASET_NAME --caption_column="text" \
|
||||
--resolution=768 \
|
||||
--train_batch_size=8 \
|
||||
--num_train_epochs=100 --checkpointing_steps=5000 \
|
||||
--learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
|
||||
--seed=42 \
|
||||
--rank=4 \
|
||||
--validation_prompt="cute dragon creature" \
|
||||
--report_to="wandb" \
|
||||
--push_to_hub \
|
||||
--output_dir="wuerstchen-prior-pokemon-lora"
|
||||
```
|
||||
@@ -0,0 +1,23 @@
|
||||
import torch.nn as nn
|
||||
from torchvision.models import efficientnet_v2_l, efficientnet_v2_s
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
class EfficientNetEncoder(ModelMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(self, c_latent=16, c_cond=1280, effnet="efficientnet_v2_s"):
|
||||
super().__init__()
|
||||
|
||||
if effnet == "efficientnet_v2_s":
|
||||
self.backbone = efficientnet_v2_s(weights="DEFAULT").features
|
||||
else:
|
||||
self.backbone = efficientnet_v2_l(weights="DEFAULT").features
|
||||
self.mapper = nn.Sequential(
|
||||
nn.Conv2d(c_cond, c_latent, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.mapper(self.backbone(x))
|
||||
@@ -0,0 +1,7 @@
|
||||
accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
wandb
|
||||
huggingface-cli
|
||||
bitsandbytes
|
||||
deepspeed
|
||||
@@ -0,0 +1,888 @@
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.state import AcceleratorState, is_initialized
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import create_repo, hf_hub_download, upload_folder
|
||||
from modeling_efficient_net_encoder import EfficientNetEncoder
|
||||
from torchvision import transforms
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTextModel, PreTrainedTokenizerFast
|
||||
from transformers.utils import ContextManagers
|
||||
|
||||
from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler, WuerstchenPriorPipeline
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior
|
||||
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
|
||||
from diffusers.utils.logging import set_verbosity_error, set_verbosity_info
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.22.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
DATASET_NAME_MAPPING = {
|
||||
"lambdalabs/pokemon-blip-captions": ("image", "text"),
|
||||
}
|
||||
|
||||
|
||||
def save_model_card(
|
||||
args,
|
||||
repo_id: str,
|
||||
images=None,
|
||||
repo_folder=None,
|
||||
):
|
||||
img_str = ""
|
||||
if len(images) > 0:
|
||||
image_grid = make_image_grid(images, 1, len(args.validation_prompts))
|
||||
image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
|
||||
img_str += "\n"
|
||||
|
||||
yaml = f"""
|
||||
---
|
||||
license: mit
|
||||
base_model: {args.pretrained_prior_model_name_or_path}
|
||||
datasets:
|
||||
- {args.dataset_name}
|
||||
tags:
|
||||
- wuerstchen
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- lora
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
# LoRA Finetuning - {repo_id}
|
||||
|
||||
This pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
|
||||
{img_str}
|
||||
|
||||
## Pipeline usage
|
||||
|
||||
You can use the pipeline like so:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"{args.pretrained_decoder_model_name_or_path}", torch_dtype={args.weight_dtype}
|
||||
)
|
||||
# load lora weights from folder:
|
||||
pipeline.prior_pipe.load_lora_weights("{repo_id}", torch_dtype={args.weight_dtype})
|
||||
|
||||
image = pipeline(prompt=prompt).images[0]
|
||||
image.save("my_image.png")
|
||||
```
|
||||
|
||||
## Training info
|
||||
|
||||
These are the key hyperparameters used during training:
|
||||
|
||||
* LoRA rank: {args.rank}
|
||||
* Epochs: {args.num_train_epochs}
|
||||
* Learning rate: {args.learning_rate}
|
||||
* Batch size: {args.train_batch_size}
|
||||
* Gradient accumulation steps: {args.gradient_accumulation_steps}
|
||||
* Image resolution: {args.resolution}
|
||||
* Mixed-precision: {args.mixed_precision}
|
||||
|
||||
"""
|
||||
wandb_info = ""
|
||||
if is_wandb_available():
|
||||
wandb_run_url = None
|
||||
if wandb.run is not None:
|
||||
wandb_run_url = wandb.run.url
|
||||
|
||||
if wandb_run_url is not None:
|
||||
wandb_info = f"""
|
||||
More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
|
||||
"""
|
||||
|
||||
model_card += wandb_info
|
||||
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
|
||||
|
||||
def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator, weight_dtype, epoch):
|
||||
logger.info("Running validation... ")
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
args.pretrained_decoder_model_name_or_path,
|
||||
prior_text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
prior_tokenizer=tokenizer,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.prior_prior.set_attn_processor(attn_processors)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
if args.seed is None:
|
||||
generator = None
|
||||
else:
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
|
||||
images = []
|
||||
for i in range(len(args.validation_prompts)):
|
||||
with torch.autocast("cuda"):
|
||||
image = pipeline(
|
||||
args.validation_prompts[i],
|
||||
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
|
||||
generator=generator,
|
||||
height=args.resolution,
|
||||
width=args.resolution,
|
||||
).images[0]
|
||||
|
||||
images.append(image)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
||||
elif tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
"validation": [
|
||||
wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
|
||||
for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of finetuning Würstchen Prior.")
|
||||
parser.add_argument(
|
||||
"--rank",
|
||||
type=int,
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_decoder_model_name_or_path",
|
||||
type=str,
|
||||
default="warp-ai/wuerstchen",
|
||||
required=False,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_prior_model_name_or_path",
|
||||
type=str,
|
||||
default="warp-ai/wuerstchen-prior",
|
||||
required=False,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
||||
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
||||
" or to a folder containing files that 🤗 Datasets can understand."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The config of the Dataset, leave as None if there's only one config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"A folder containing the training data. Folder contents must follow the structure described in"
|
||||
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
||||
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_column",
|
||||
type=str,
|
||||
default="text",
|
||||
help="The column of the dataset containing a caption or a list of captions.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_train_samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_prompts",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="+",
|
||||
help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="wuerstchen-model-finetuned-lora",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The directory where the downloaded models and datasets will be stored.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument("--num_train_epochs", type=int, default=100)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="learning rate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allow_tf32",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
||||
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument(
|
||||
"--adam_weight_decay",
|
||||
type=float,
|
||||
default=0.0,
|
||||
required=False,
|
||||
help="weight decay_to_use",
|
||||
)
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
||||
" training using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoints_total_limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help=("Max number of checkpoints to store."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_epochs",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Run validation every X epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tracker_project_name",
|
||||
type=str,
|
||||
default="text2image-fine-tune",
|
||||
help=(
|
||||
"The `project_name` argument passed to Accelerator.init_trackers for"
|
||||
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
# Sanity checks
|
||||
if args.dataset_name is None and args.train_data_dir is None:
|
||||
raise ValueError("Need either a dataset name or a training folder.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
||||
accelerator_project_config = ProjectConfiguration(
|
||||
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
|
||||
)
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
datasets.utils.logging.set_verbosity_warning()
|
||||
transformers.utils.logging.set_verbosity_warning()
|
||||
set_verbosity_info()
|
||||
else:
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
set_verbosity_error()
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
||||
).repo_id
|
||||
|
||||
# Load scheduler, effnet, tokenizer, clip_model
|
||||
noise_scheduler = DDPMWuerstchenScheduler()
|
||||
tokenizer = PreTrainedTokenizerFast.from_pretrained(
|
||||
args.pretrained_prior_model_name_or_path, subfolder="tokenizer"
|
||||
)
|
||||
|
||||
def deepspeed_zero_init_disabled_context_manager():
|
||||
"""
|
||||
returns either a context list that includes one that will disable zero.Init or an empty context list
|
||||
"""
|
||||
deepspeed_plugin = AcceleratorState().deepspeed_plugin if is_initialized() else None
|
||||
if deepspeed_plugin is None:
|
||||
return []
|
||||
|
||||
return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
|
||||
pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
|
||||
state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
|
||||
image_encoder = EfficientNetEncoder()
|
||||
image_encoder.load_state_dict(state_dict["effnet_state_dict"])
|
||||
image_encoder.eval()
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
|
||||
).eval()
|
||||
|
||||
# Freeze text_encoder, cast to weight_dtype and image_encoder and move to device
|
||||
text_encoder.requires_grad_(False)
|
||||
image_encoder.requires_grad_(False)
|
||||
image_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# load prior model, cast to weight_dtype and move to device
|
||||
prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
|
||||
prior.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# lora attn processor
|
||||
lora_attn_procs = {}
|
||||
for name in prior.attn_processors.keys():
|
||||
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=prior.config["c"], rank=args.rank)
|
||||
prior.set_attn_processor(lora_attn_procs)
|
||||
lora_layers = AttnProcsLayers(prior.attn_processors)
|
||||
|
||||
if args.allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
||||
)
|
||||
|
||||
optimizer_cls = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
optimizer = optimizer_cls(
|
||||
lora_layers.parameters(),
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
# Get the datasets: you can either provide your own training and evaluation files (see below)
|
||||
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
|
||||
|
||||
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
||||
# download the dataset.
|
||||
if args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
else:
|
||||
data_files = {}
|
||||
if args.train_data_dir is not None:
|
||||
data_files["train"] = os.path.join(args.train_data_dir, "**")
|
||||
dataset = load_dataset(
|
||||
"imagefolder",
|
||||
data_files=data_files,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
# See more about loading custom images at
|
||||
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# We need to tokenize inputs and targets.
|
||||
column_names = dataset["train"].column_names
|
||||
|
||||
# Get the column names for input/target.
|
||||
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
|
||||
if args.image_column is None:
|
||||
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
||||
else:
|
||||
image_column = args.image_column
|
||||
if image_column not in column_names:
|
||||
raise ValueError(
|
||||
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
|
||||
)
|
||||
if args.caption_column is None:
|
||||
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
|
||||
else:
|
||||
caption_column = args.caption_column
|
||||
if caption_column not in column_names:
|
||||
raise ValueError(
|
||||
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
|
||||
)
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# We need to tokenize input captions and transform the images
|
||||
def tokenize_captions(examples, is_train=True):
|
||||
captions = []
|
||||
for caption in examples[caption_column]:
|
||||
if isinstance(caption, str):
|
||||
captions.append(caption)
|
||||
elif isinstance(caption, (list, np.ndarray)):
|
||||
# take a random caption if there are multiple
|
||||
captions.append(random.choice(caption) if is_train else caption[0])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Caption column `{caption_column}` should contain either strings or lists of strings."
|
||||
)
|
||||
inputs = tokenizer(
|
||||
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
)
|
||||
text_input_ids = inputs.input_ids
|
||||
text_mask = inputs.attention_mask.bool()
|
||||
return text_input_ids, text_mask
|
||||
|
||||
effnet_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
||||
]
|
||||
)
|
||||
|
||||
def preprocess_train(examples):
|
||||
images = [image.convert("RGB") for image in examples[image_column]]
|
||||
examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images]
|
||||
examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
|
||||
return examples
|
||||
|
||||
with accelerator.main_process_first():
|
||||
if args.max_train_samples is not None:
|
||||
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
||||
# Set the training transforms
|
||||
train_dataset = dataset["train"].with_transform(preprocess_train)
|
||||
|
||||
def collate_fn(examples):
|
||||
effnet_pixel_values = torch.stack([example["effnet_pixel_values"] for example in examples])
|
||||
effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
|
||||
text_mask = torch.stack([example["text_mask"] for example in examples])
|
||||
return {"effnet_pixel_values": effnet_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}
|
||||
|
||||
# DataLoaders creation:
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
batch_size=args.train_batch_size,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
lora_layers, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initializes automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
tracker_config = dict(vars(args))
|
||||
tracker_config.pop("validation_prompts")
|
||||
accelerator.init_trackers(args.tracker_project_name, tracker_config)
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
# Potentially load in the weights and states from a previous save
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = os.listdir(args.output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1] if len(dirs) > 0 else None
|
||||
|
||||
if path is None:
|
||||
accelerator.print(
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
prior.train()
|
||||
train_loss = 0.0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(prior):
|
||||
# Convert images to latent space
|
||||
text_input_ids, text_mask, effnet_images = (
|
||||
batch["text_input_ids"],
|
||||
batch["text_mask"],
|
||||
batch["effnet_pixel_values"].to(weight_dtype),
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
text_encoder_output = text_encoder(text_input_ids, attention_mask=text_mask)
|
||||
prompt_embeds = text_encoder_output.last_hidden_state
|
||||
image_embeds = image_encoder(effnet_images)
|
||||
# scale
|
||||
image_embeds = image_embeds.add(1.0).div(42.0)
|
||||
|
||||
# Sample noise that we'll add to the image_embeds
|
||||
noise = torch.randn_like(image_embeds)
|
||||
bsz = image_embeds.shape[0]
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype)
|
||||
|
||||
# add noise to latent
|
||||
noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
|
||||
|
||||
# Predict the noise residual and compute losscd
|
||||
pred_noise = prior(noisy_latents, timesteps, prompt_embeds)
|
||||
|
||||
# vanilla loss
|
||||
loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean")
|
||||
|
||||
# Gather the losses across all processes for logging (if we use distributed training).
|
||||
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
||||
train_loss += avg_loss.item() / args.gradient_accumulation_steps
|
||||
|
||||
# Backpropagate
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(lora_layers.parameters(), args.max_grad_norm)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
accelerator.log({"train_loss": train_loss}, step=global_step)
|
||||
train_loss = 0.0
|
||||
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if args.checkpoints_total_limit is not None:
|
||||
checkpoints = os.listdir(args.output_dir)
|
||||
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
||||
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
||||
|
||||
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
||||
if len(checkpoints) >= args.checkpoints_total_limit:
|
||||
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
||||
removing_checkpoints = checkpoints[0:num_to_remove]
|
||||
|
||||
logger.info(
|
||||
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
||||
)
|
||||
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
||||
|
||||
for removing_checkpoint in removing_checkpoints:
|
||||
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
||||
shutil.rmtree(removing_checkpoint)
|
||||
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
|
||||
log_validation(
|
||||
text_encoder, tokenizer, prior.attn_processors, args, accelerator, weight_dtype, global_step
|
||||
)
|
||||
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
prior = prior.to(torch.float32)
|
||||
WuerstchenPriorPipeline.save_lora_weights(
|
||||
os.path.join(args.output_dir, "prior_lora"),
|
||||
unet_lora_layers=lora_layers,
|
||||
)
|
||||
|
||||
# Run a final round of inference.
|
||||
images = []
|
||||
if args.validation_prompts is not None:
|
||||
logger.info("Running inference for collecting generated images...")
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
args.pretrained_decoder_model_name_or_path,
|
||||
prior_text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
prior_tokenizer=tokenizer,
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype)
|
||||
# load lora weights
|
||||
pipeline.prior_pipe.load_lora_weights(os.path.join(args.output_dir, "prior_lora"))
|
||||
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
if args.seed is None:
|
||||
generator = None
|
||||
else:
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
|
||||
for i in range(len(args.validation_prompts)):
|
||||
with torch.autocast("cuda"):
|
||||
image = pipeline(
|
||||
args.validation_prompts[i],
|
||||
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
|
||||
generator=generator,
|
||||
width=args.resolution,
|
||||
height=args.resolution,
|
||||
).images[0]
|
||||
images.append(image)
|
||||
|
||||
if args.push_to_hub:
|
||||
save_model_card(args, repo_id, images, repo_folder=args.output_dir)
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
commit_message="End of training",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,925 @@
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.state import AcceleratorState, is_initialized
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import create_repo, hf_hub_download, upload_folder
|
||||
from modeling_efficient_net_encoder import EfficientNetEncoder
|
||||
from packaging import version
|
||||
from torchvision import transforms
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTextModel, PreTrainedTokenizerFast
|
||||
from transformers.utils import ContextManagers
|
||||
|
||||
from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
|
||||
from diffusers.utils.logging import set_verbosity_error, set_verbosity_info
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.22.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
DATASET_NAME_MAPPING = {
|
||||
"lambdalabs/pokemon-blip-captions": ("image", "text"),
|
||||
}
|
||||
|
||||
|
||||
def save_model_card(
|
||||
args,
|
||||
repo_id: str,
|
||||
images=None,
|
||||
repo_folder=None,
|
||||
):
|
||||
img_str = ""
|
||||
if len(images) > 0:
|
||||
image_grid = make_image_grid(images, 1, len(args.validation_prompts))
|
||||
image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
|
||||
img_str += "\n"
|
||||
|
||||
yaml = f"""
|
||||
---
|
||||
license: mit
|
||||
base_model: {args.pretrained_prior_model_name_or_path}
|
||||
datasets:
|
||||
- {args.dataset_name}
|
||||
tags:
|
||||
- wuerstchen
|
||||
- text-to-image
|
||||
- diffusers
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
# Finetuning - {repo_id}
|
||||
|
||||
This pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
|
||||
{img_str}
|
||||
|
||||
## Pipeline usage
|
||||
|
||||
You can use the pipeline like so:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe_prior = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype={args.weight_dtype})
|
||||
pipe_t2i = DiffusionPipeline.from_pretrained("{args.pretrained_decoder_model_name_or_path}", torch_dtype={args.weight_dtype})
|
||||
prompt = "{args.validation_prompts[0]}"
|
||||
(image_embeds,) = pipe_prior(prompt).to_tuple()
|
||||
image = pipe_t2i(image_embeddings=image_embeds, prompt=prompt).images[0]
|
||||
image.save("my_image.png")
|
||||
```
|
||||
|
||||
## Training info
|
||||
|
||||
These are the key hyperparameters used during training:
|
||||
|
||||
* Epochs: {args.num_train_epochs}
|
||||
* Learning rate: {args.learning_rate}
|
||||
* Batch size: {args.train_batch_size}
|
||||
* Gradient accumulation steps: {args.gradient_accumulation_steps}
|
||||
* Image resolution: {args.resolution}
|
||||
* Mixed-precision: {args.mixed_precision}
|
||||
|
||||
"""
|
||||
wandb_info = ""
|
||||
if is_wandb_available():
|
||||
wandb_run_url = None
|
||||
if wandb.run is not None:
|
||||
wandb_run_url = wandb.run.url
|
||||
|
||||
if wandb_run_url is not None:
|
||||
wandb_info = f"""
|
||||
More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
|
||||
"""
|
||||
|
||||
model_card += wandb_info
|
||||
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
|
||||
|
||||
def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch):
|
||||
logger.info("Running validation... ")
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
args.pretrained_decoder_model_name_or_path,
|
||||
prior_prior=accelerator.unwrap_model(prior),
|
||||
prior_text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
prior_tokenizer=tokenizer,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
if args.seed is None:
|
||||
generator = None
|
||||
else:
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
|
||||
images = []
|
||||
for i in range(len(args.validation_prompts)):
|
||||
with torch.autocast("cuda"):
|
||||
image = pipeline(
|
||||
args.validation_prompts[i],
|
||||
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
|
||||
generator=generator,
|
||||
height=args.resolution,
|
||||
width=args.resolution,
|
||||
).images[0]
|
||||
|
||||
images.append(image)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
||||
elif tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
"validation": [
|
||||
wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
|
||||
for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of finetuning Würstchen Prior.")
|
||||
parser.add_argument(
|
||||
"--pretrained_decoder_model_name_or_path",
|
||||
type=str,
|
||||
default="warp-ai/wuerstchen",
|
||||
required=False,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_prior_model_name_or_path",
|
||||
type=str,
|
||||
default="warp-ai/wuerstchen-prior",
|
||||
required=False,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
||||
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
||||
" or to a folder containing files that 🤗 Datasets can understand."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The config of the Dataset, leave as None if there's only one config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"A folder containing the training data. Folder contents must follow the structure described in"
|
||||
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
||||
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_column",
|
||||
type=str,
|
||||
default="text",
|
||||
help="The column of the dataset containing a caption or a list of captions.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_train_samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_prompts",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="+",
|
||||
help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="wuerstchen-model-finetuned",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The directory where the downloaded models and datasets will be stored.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument("--num_train_epochs", type=int, default=100)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing",
|
||||
action="store_true",
|
||||
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="learning rate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allow_tf32",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
||||
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
||||
),
|
||||
)
|
||||
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument(
|
||||
"--adam_weight_decay",
|
||||
type=float,
|
||||
default=0.0,
|
||||
required=False,
|
||||
help="weight decay_to_use",
|
||||
)
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
||||
" training using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoints_total_limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help=("Max number of checkpoints to store."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_epochs",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Run validation every X epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tracker_project_name",
|
||||
type=str,
|
||||
default="text2image-fine-tune",
|
||||
help=(
|
||||
"The `project_name` argument passed to Accelerator.init_trackers for"
|
||||
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
# Sanity checks
|
||||
if args.dataset_name is None and args.train_data_dir is None:
|
||||
raise ValueError("Need either a dataset name or a training folder.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
||||
accelerator_project_config = ProjectConfiguration(
|
||||
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
|
||||
)
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
datasets.utils.logging.set_verbosity_warning()
|
||||
transformers.utils.logging.set_verbosity_warning()
|
||||
set_verbosity_info()
|
||||
else:
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
set_verbosity_error()
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
||||
).repo_id
|
||||
|
||||
# Load scheduler, effnet, tokenizer, clip_model
|
||||
noise_scheduler = DDPMWuerstchenScheduler()
|
||||
tokenizer = PreTrainedTokenizerFast.from_pretrained(
|
||||
args.pretrained_prior_model_name_or_path, subfolder="tokenizer"
|
||||
)
|
||||
|
||||
def deepspeed_zero_init_disabled_context_manager():
|
||||
"""
|
||||
returns either a context list that includes one that will disable zero.Init or an empty context list
|
||||
"""
|
||||
deepspeed_plugin = AcceleratorState().deepspeed_plugin if is_initialized() else None
|
||||
if deepspeed_plugin is None:
|
||||
return []
|
||||
|
||||
return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
|
||||
pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
|
||||
state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
|
||||
image_encoder = EfficientNetEncoder()
|
||||
image_encoder.load_state_dict(state_dict["effnet_state_dict"])
|
||||
image_encoder.eval()
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
|
||||
).eval()
|
||||
|
||||
# Freeze text_encoder and image_encoder
|
||||
text_encoder.requires_grad_(False)
|
||||
image_encoder.requires_grad_(False)
|
||||
|
||||
# load prior model
|
||||
prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
|
||||
|
||||
# Create EMA for the prior
|
||||
if args.use_ema:
|
||||
ema_prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
|
||||
ema_prior = EMAModel(ema_prior.parameters(), model_cls=WuerstchenPrior, model_config=ema_prior.config)
|
||||
ema_prior.to(accelerator.device)
|
||||
|
||||
# `accelerate` 0.16.0 will have better support for customized saving
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if args.use_ema:
|
||||
ema_prior.save_pretrained(os.path.join(output_dir, "prior_ema"))
|
||||
|
||||
for i, model in enumerate(models):
|
||||
model.save_pretrained(os.path.join(output_dir, "prior"))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
if args.use_ema:
|
||||
load_model = EMAModel.from_pretrained(os.path.join(input_dir, "prior_ema"), WuerstchenPrior)
|
||||
ema_prior.load_state_dict(load_model.state_dict())
|
||||
ema_prior.to(accelerator.device)
|
||||
del load_model
|
||||
|
||||
for i in range(len(models)):
|
||||
# pop models so that they are not loaded again
|
||||
model = models.pop()
|
||||
|
||||
# load diffusers style into model
|
||||
load_model = WuerstchenPrior.from_pretrained(input_dir, subfolder="prior")
|
||||
model.register_to_config(**load_model.config)
|
||||
|
||||
model.load_state_dict(load_model.state_dict())
|
||||
del load_model
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
prior.enable_gradient_checkpointing()
|
||||
|
||||
if args.allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
||||
)
|
||||
|
||||
optimizer_cls = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
optimizer = optimizer_cls(
|
||||
prior.parameters(),
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
# Get the datasets: you can either provide your own training and evaluation files (see below)
|
||||
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
|
||||
|
||||
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
||||
# download the dataset.
|
||||
if args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
else:
|
||||
data_files = {}
|
||||
if args.train_data_dir is not None:
|
||||
data_files["train"] = os.path.join(args.train_data_dir, "**")
|
||||
dataset = load_dataset(
|
||||
"imagefolder",
|
||||
data_files=data_files,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
# See more about loading custom images at
|
||||
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# We need to tokenize inputs and targets.
|
||||
column_names = dataset["train"].column_names
|
||||
|
||||
# Get the column names for input/target.
|
||||
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
|
||||
if args.image_column is None:
|
||||
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
||||
else:
|
||||
image_column = args.image_column
|
||||
if image_column not in column_names:
|
||||
raise ValueError(
|
||||
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
|
||||
)
|
||||
if args.caption_column is None:
|
||||
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
|
||||
else:
|
||||
caption_column = args.caption_column
|
||||
if caption_column not in column_names:
|
||||
raise ValueError(
|
||||
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
|
||||
)
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# We need to tokenize input captions and transform the images
|
||||
def tokenize_captions(examples, is_train=True):
|
||||
captions = []
|
||||
for caption in examples[caption_column]:
|
||||
if isinstance(caption, str):
|
||||
captions.append(caption)
|
||||
elif isinstance(caption, (list, np.ndarray)):
|
||||
# take a random caption if there are multiple
|
||||
captions.append(random.choice(caption) if is_train else caption[0])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Caption column `{caption_column}` should contain either strings or lists of strings."
|
||||
)
|
||||
inputs = tokenizer(
|
||||
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
)
|
||||
text_input_ids = inputs.input_ids
|
||||
text_mask = inputs.attention_mask.bool()
|
||||
return text_input_ids, text_mask
|
||||
|
||||
effnet_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
||||
]
|
||||
)
|
||||
|
||||
def preprocess_train(examples):
|
||||
images = [image.convert("RGB") for image in examples[image_column]]
|
||||
examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images]
|
||||
examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
|
||||
return examples
|
||||
|
||||
with accelerator.main_process_first():
|
||||
if args.max_train_samples is not None:
|
||||
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
||||
# Set the training transforms
|
||||
train_dataset = dataset["train"].with_transform(preprocess_train)
|
||||
|
||||
def collate_fn(examples):
|
||||
effnet_pixel_values = torch.stack([example["effnet_pixel_values"] for example in examples])
|
||||
effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
|
||||
text_mask = torch.stack([example["text_mask"] for example in examples])
|
||||
return {"effnet_pixel_values": effnet_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}
|
||||
|
||||
# DataLoaders creation:
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
batch_size=args.train_batch_size,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
prior, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
image_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initializes automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
tracker_config = dict(vars(args))
|
||||
tracker_config.pop("validation_prompts")
|
||||
accelerator.init_trackers(args.tracker_project_name, tracker_config)
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
# Potentially load in the weights and states from a previous save
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = os.listdir(args.output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1] if len(dirs) > 0 else None
|
||||
|
||||
if path is None:
|
||||
accelerator.print(
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
prior.train()
|
||||
train_loss = 0.0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(prior):
|
||||
# Convert images to latent space
|
||||
text_input_ids, text_mask, effnet_images = (
|
||||
batch["text_input_ids"],
|
||||
batch["text_mask"],
|
||||
batch["effnet_pixel_values"].to(weight_dtype),
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
text_encoder_output = text_encoder(text_input_ids, attention_mask=text_mask)
|
||||
prompt_embeds = text_encoder_output.last_hidden_state
|
||||
image_embeds = image_encoder(effnet_images)
|
||||
# scale
|
||||
image_embeds = image_embeds.add(1.0).div(42.0)
|
||||
|
||||
# Sample noise that we'll add to the image_embeds
|
||||
noise = torch.randn_like(image_embeds)
|
||||
bsz = image_embeds.shape[0]
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype)
|
||||
|
||||
# add noise to latent
|
||||
noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
|
||||
|
||||
# Predict the noise residual and compute losscd
|
||||
pred_noise = prior(noisy_latents, timesteps, prompt_embeds)
|
||||
|
||||
# vanilla loss
|
||||
loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean")
|
||||
|
||||
# Gather the losses across all processes for logging (if we use distributed training).
|
||||
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
||||
train_loss += avg_loss.item() / args.gradient_accumulation_steps
|
||||
|
||||
# Backpropagate
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
if args.use_ema:
|
||||
ema_prior.step(prior.parameters())
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
accelerator.log({"train_loss": train_loss}, step=global_step)
|
||||
train_loss = 0.0
|
||||
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if args.checkpoints_total_limit is not None:
|
||||
checkpoints = os.listdir(args.output_dir)
|
||||
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
||||
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
||||
|
||||
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
||||
if len(checkpoints) >= args.checkpoints_total_limit:
|
||||
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
||||
removing_checkpoints = checkpoints[0:num_to_remove]
|
||||
|
||||
logger.info(
|
||||
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
||||
)
|
||||
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
||||
|
||||
for removing_checkpoint in removing_checkpoints:
|
||||
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
||||
shutil.rmtree(removing_checkpoint)
|
||||
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
|
||||
if args.use_ema:
|
||||
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
|
||||
ema_prior.store(prior.parameters())
|
||||
ema_prior.copy_to(prior.parameters())
|
||||
log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step)
|
||||
if args.use_ema:
|
||||
# Switch back to the original UNet parameters.
|
||||
ema_prior.restore(prior.parameters())
|
||||
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
prior = accelerator.unwrap_model(prior)
|
||||
if args.use_ema:
|
||||
ema_prior.copy_to(prior.parameters())
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
args.pretrained_decoder_model_name_or_path,
|
||||
prior_prior=prior,
|
||||
prior_text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
prior_tokenizer=tokenizer,
|
||||
)
|
||||
pipeline.prior_pipe.save_pretrained(os.path.join(args.output_dir, "prior_pipeline"))
|
||||
|
||||
# Run a final round of inference.
|
||||
images = []
|
||||
if args.validation_prompts is not None:
|
||||
logger.info("Running inference for collecting generated images...")
|
||||
pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
if args.seed is None:
|
||||
generator = None
|
||||
else:
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
|
||||
for i in range(len(args.validation_prompts)):
|
||||
with torch.autocast("cuda"):
|
||||
image = pipeline(
|
||||
args.validation_prompts[i],
|
||||
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
|
||||
generator=generator,
|
||||
width=args.resolution,
|
||||
height=args.resolution,
|
||||
).images[0]
|
||||
images.append(image)
|
||||
|
||||
if args.push_to_hub:
|
||||
save_model_card(args, repo_id, images, repo_folder=args.output_dir)
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
commit_message="End of training",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -343,6 +343,7 @@ class ConfigMixin:
|
||||
user_agent = http_user_agent(user_agent)
|
||||
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
print("load_config() is called.")
|
||||
|
||||
if cls.config_name is None:
|
||||
raise ValueError(
|
||||
@@ -485,10 +486,18 @@ class ConfigMixin:
|
||||
|
||||
# remove attributes from orig class that cannot be expected
|
||||
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
|
||||
if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
|
||||
if (
|
||||
isinstance(orig_cls_name, str)
|
||||
and orig_cls_name != cls.__name__
|
||||
and hasattr(diffusers_library, orig_cls_name)
|
||||
):
|
||||
orig_cls = getattr(diffusers_library, orig_cls_name)
|
||||
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
|
||||
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
|
||||
elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
|
||||
raise ValueError(
|
||||
"Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
|
||||
)
|
||||
|
||||
# remove private attributes
|
||||
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
|
||||
|
||||
@@ -1208,7 +1208,7 @@ class LoraLoaderMixin:
|
||||
self.load_lora_into_unet(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
unet=self.unet,
|
||||
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
@@ -1216,7 +1216,9 @@ class LoraLoaderMixin:
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
text_encoder=getattr(self, self.text_encoder_name)
|
||||
if not hasattr(self, "text_encoder")
|
||||
else self.text_encoder,
|
||||
lora_scale=self.lora_scale,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
adapter_name=adapter_name,
|
||||
@@ -1577,7 +1579,7 @@ class LoraLoaderMixin:
|
||||
"""
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
||||
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
@@ -1961,7 +1963,7 @@ class LoraLoaderMixin:
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
self,
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
|
||||
@@ -2001,7 +2003,7 @@ class LoraLoaderMixin:
|
||||
unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
|
||||
)
|
||||
|
||||
unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
|
||||
unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()}
|
||||
state_dict.update(unet_lora_state_dict)
|
||||
|
||||
if text_encoder_lora_layers is not None:
|
||||
@@ -2012,12 +2014,12 @@ class LoraLoaderMixin:
|
||||
)
|
||||
|
||||
text_encoder_lora_state_dict = {
|
||||
f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
|
||||
f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
|
||||
}
|
||||
state_dict.update(text_encoder_lora_state_dict)
|
||||
|
||||
# Save the model
|
||||
self.write_lora_layers(
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
@@ -2026,6 +2028,7 @@ class LoraLoaderMixin:
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def write_lora_layers(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
save_directory: str,
|
||||
@@ -2829,6 +2832,7 @@ class FromSingleFileMixin:
|
||||
tokenizer=tokenizer,
|
||||
original_config_file=original_config_file,
|
||||
config_files=config_files,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
|
||||
if torch_dtype is not None:
|
||||
@@ -3248,7 +3252,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
self,
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
@@ -3299,7 +3303,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
||||
|
||||
self.write_lora_layers(
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
|
||||
@@ -207,10 +207,6 @@ class BasicTransformerBlock(nn.Module):
|
||||
)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
# print("After first norm")
|
||||
# print(f"hidden_states: {hidden_states.dtype}")
|
||||
# print(f"norm_hidden_states: {norm_hidden_states.dtype}")
|
||||
# print(f"encoder_hidden_states: {norm_hidden_states.dtype}")
|
||||
|
||||
# 1. Retrieve lora scale.
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
@@ -227,9 +223,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
)
|
||||
if self.use_ada_layer_norm_zero:
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
# print(f"attn_output: {attn_output.dtype}")
|
||||
hidden_states = attn_output + hidden_states
|
||||
# print(f"attn_output: {attn_output.dtype}")
|
||||
|
||||
# 2.5 GLIGEN Control
|
||||
if gligen_kwargs is not None:
|
||||
|
||||
@@ -84,7 +84,6 @@ class LoRALinearLayer(nn.Module):
|
||||
nn.init.zeros_(self.up.weight)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# print(f"From {self.__class__.__name__}: hidden_states: {hidden_states.dtype}")
|
||||
orig_dtype = hidden_states.dtype
|
||||
dtype = self.down.weight.dtype
|
||||
|
||||
@@ -94,9 +93,7 @@ class LoRALinearLayer(nn.Module):
|
||||
if self.network_alpha is not None:
|
||||
up_hidden_states *= self.network_alpha / self.rank
|
||||
|
||||
out = up_hidden_states.to(orig_dtype)
|
||||
# print(f"From {self.__class__.__name__}: out: {out.dtype}")
|
||||
return out
|
||||
return up_hidden_states.to(orig_dtype)
|
||||
|
||||
|
||||
class LoRAConv2dLayer(nn.Module):
|
||||
|
||||
@@ -291,6 +291,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
elif self.class_embedding is None and class_labels is not None:
|
||||
raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
|
||||
|
||||
# 2. pre-process
|
||||
skip_sample = sample
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -503,6 +503,36 @@ class AutoencoderTinyBlock(nn.Module):
|
||||
|
||||
|
||||
class UNetMidBlock2D(nn.Module):
|
||||
"""
|
||||
A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): The number of input channels.
|
||||
temb_channels (`int`): The number of temporal embedding channels.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
||||
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
|
||||
The type of normalization to apply to the time embeddings. This can help to improve the performance of the
|
||||
model on tasks with long-range temporal dependencies.
|
||||
resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
|
||||
resnet_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use in the group normalization layers of the resnet blocks.
|
||||
attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
|
||||
resnet_pre_norm (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use pre-normalization for the resnet blocks.
|
||||
add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
|
||||
attention_head_dim (`int`, *optional*, defaults to 1):
|
||||
Dimension of a single attention head. The number of attention heads is determined based on this value and
|
||||
the number of input channels.
|
||||
output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
||||
in_channels, height, width)`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -604,7 +634,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
@@ -624,6 +654,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
self.num_attention_heads = num_attention_heads
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
# support for variable transformer layers per block
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlock2D(
|
||||
@@ -641,14 +675,14 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
]
|
||||
attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
for i in range(num_layers):
|
||||
if not dual_cross_attention:
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
num_attention_heads,
|
||||
in_channels // num_attention_heads,
|
||||
in_channels=in_channels,
|
||||
num_layers=transformer_layers_per_block,
|
||||
num_layers=transformer_layers_per_block[i],
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
@@ -988,7 +1022,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
@@ -1011,6 +1045,8 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.num_attention_heads = num_attention_heads
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
@@ -1034,7 +1070,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
num_attention_heads,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=transformer_layers_per_block,
|
||||
num_layers=transformer_layers_per_block[i],
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
@@ -2137,7 +2173,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
@@ -2160,6 +2196,9 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
self.has_cross_attention = True
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
@@ -2184,7 +2223,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
num_attention_heads,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=transformer_layers_per_block,
|
||||
num_layers=transformer_layers_per_block[i],
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch.utils.checkpoint
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import UNet2DConditionLoadersMixin
|
||||
from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from .activations import get_activation
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
@@ -43,6 +43,7 @@ from .embeddings import (
|
||||
)
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import (
|
||||
UNetMidBlock2D,
|
||||
UNetMidBlock2DCrossAttn,
|
||||
UNetMidBlock2DSimpleCrossAttn,
|
||||
get_down_block,
|
||||
@@ -86,7 +87,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
||||
Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
|
||||
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
||||
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
||||
The tuple of upsample blocks to use.
|
||||
@@ -105,10 +106,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
|
||||
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
@@ -142,9 +148,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
||||
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of `cond_proj` layer in the timestep embedding.
|
||||
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
|
||||
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
|
||||
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
||||
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
|
||||
*optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
|
||||
*optional*): The dimension of the `class_labels` input when
|
||||
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
|
||||
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
||||
embeddings with the class embeddings.
|
||||
@@ -184,7 +190,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
||||
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
@@ -265,6 +272,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
|
||||
for layer_number_per_block in transformer_layers_per_block:
|
||||
if isinstance(layer_number_per_block, list):
|
||||
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
|
||||
|
||||
# input
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
@@ -500,6 +511,19 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
only_cross_attention=mid_block_only_cross_attention,
|
||||
cross_attention_norm=cross_attention_norm,
|
||||
)
|
||||
elif mid_block_type == "UNetMidBlock2D":
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=blocks_time_embed_dim,
|
||||
dropout=dropout,
|
||||
num_layers=0,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
add_attention=False,
|
||||
)
|
||||
elif mid_block_type is None:
|
||||
self.mid_block = None
|
||||
else:
|
||||
@@ -513,7 +537,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
||||
reversed_layers_per_block = list(reversed(layers_per_block))
|
||||
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
||||
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
||||
reversed_transformer_layers_per_block = (
|
||||
list(reversed(transformer_layers_per_block))
|
||||
if reverse_transformer_layers_per_block is None
|
||||
else reverse_transformer_layers_per_block
|
||||
)
|
||||
only_cross_attention = list(reversed(only_cross_attention))
|
||||
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
@@ -778,6 +806,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
||||
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
@@ -822,6 +851,13 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
added_cond_kwargs: (`dict`, *optional*):
|
||||
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
||||
are passed along to the UNet blocks.
|
||||
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
||||
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
|
||||
example from ControlNet side model(s)
|
||||
mid_block_additional_residual (`torch.Tensor`, *optional*):
|
||||
additional residual to be added to UNet mid block output, for example from ControlNet side model
|
||||
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
||||
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
@@ -1000,15 +1036,30 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
scale_lora_layers(self, lora_scale)
|
||||
|
||||
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
||||
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
|
||||
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
||||
is_adapter = down_intrablock_additional_residuals is not None
|
||||
# maintain backward compatibility for legacy usage, where
|
||||
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
|
||||
# but can only use one or the other
|
||||
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
|
||||
deprecate(
|
||||
"T2I should not use down_block_additional_residuals",
|
||||
"1.3.0",
|
||||
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
|
||||
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
|
||||
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
|
||||
standard_warn=False,
|
||||
)
|
||||
down_intrablock_additional_residuals = down_block_additional_residuals
|
||||
is_adapter = True
|
||||
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
# For t2i-adapter CrossAttnDownBlock2D
|
||||
additional_residuals = {}
|
||||
if is_adapter and len(down_block_additional_residuals) > 0:
|
||||
additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
|
||||
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
|
||||
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
@@ -1021,9 +1072,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
|
||||
|
||||
if is_adapter and len(down_block_additional_residuals) > 0:
|
||||
sample += down_block_additional_residuals.pop(0)
|
||||
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||
sample += down_intrablock_additional_residuals.pop(0)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
@@ -1040,21 +1090,25 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# To support T2I-Adapter-XL
|
||||
if (
|
||||
is_adapter
|
||||
and len(down_block_additional_residuals) > 0
|
||||
and sample.shape == down_block_additional_residuals[0].shape
|
||||
and len(down_intrablock_additional_residuals) > 0
|
||||
and sample.shape == down_intrablock_additional_residuals[0].shape
|
||||
):
|
||||
sample += down_block_additional_residuals.pop(0)
|
||||
sample += down_intrablock_additional_residuals.pop(0)
|
||||
|
||||
if is_controlnet:
|
||||
sample = sample + mid_block_additional_residual
|
||||
@@ -1099,7 +1153,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self)
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
@@ -106,6 +106,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
@@ -441,7 +442,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -134,6 +134,7 @@ class AltDiffusionImg2ImgPipeline(
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
@@ -440,7 +441,7 @@ class AltDiffusionImg2ImgPipeline(
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -424,7 +424,7 @@ class StableDiffusionControlNetPipeline(
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -448,7 +448,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -575,7 +575,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -168,7 +168,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = ["tokenizer", "text_encoder"]
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -317,12 +317,17 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
if self.text_encoder is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -438,7 +443,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
if self.text_encoder_2 is not None:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -447,7 +456,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
@@ -459,10 +473,15 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
@@ -885,7 +904,14 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype
|
||||
self,
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
aesthetic_score,
|
||||
negative_aesthetic_score,
|
||||
dtype,
|
||||
text_encoder_projection_dim=None,
|
||||
):
|
||||
if self.config.requires_aesthetics_score:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
|
||||
@@ -895,7 +921,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
@@ -1391,6 +1417,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
|
||||
# 10. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
@@ -1398,6 +1429,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
aesthetic_score,
|
||||
negative_aesthetic_score,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
|
||||
@@ -139,9 +139,9 @@ class StableDiffusionXLControlNetPipeline(
|
||||
watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no
|
||||
watermarker is used.
|
||||
"""
|
||||
model_cpu_offload_seq = (
|
||||
"text_encoder->text_encoder_2->unet->vae" # leave controlnet out on purpose because it iterates with unet
|
||||
)
|
||||
# leave controlnet out on purpose because it iterates with unet
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -285,12 +285,17 @@ class StableDiffusionXLControlNetPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
if self.text_encoder is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -406,7 +411,11 @@ class StableDiffusionXLControlNetPipeline(
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
if self.text_encoder_2 is not None:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -415,7 +424,12 @@ class StableDiffusionXLControlNetPipeline(
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
@@ -427,10 +441,15 @@ class StableDiffusionXLControlNetPipeline(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
@@ -706,11 +725,13 @@ class StableDiffusionXLControlNetPipeline(
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
|
||||
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
||||
):
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
@@ -1088,8 +1109,17 @@ class StableDiffusionXLControlNetPipeline(
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
@@ -1098,6 +1128,7 @@ class StableDiffusionXLControlNetPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
|
||||
@@ -183,7 +183,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
watermarker will be used.
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = ["tokenizer", "text_encoder"]
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -329,12 +329,17 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
if self.text_encoder is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -450,7 +455,11 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
if self.text_encoder_2 is not None:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -459,7 +468,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
@@ -471,10 +485,15 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
@@ -832,6 +851,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype,
|
||||
text_encoder_projection_dim=None,
|
||||
):
|
||||
if self.config.requires_aesthetics_score:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
|
||||
@@ -843,7 +863,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
@@ -1275,6 +1295,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
if negative_target_size is None:
|
||||
negative_target_size = target_size
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
@@ -1285,6 +1311,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
|
||||
@@ -161,11 +161,11 @@ class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
@torch.no_grad()
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
do_classifier_free_guidance=True,
|
||||
num_images_per_prompt=1,
|
||||
device=None,
|
||||
negative_prompt=None,
|
||||
prompt: Union[str, List[str]],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
clean_caption: bool = False,
|
||||
@@ -174,14 +174,14 @@ class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
@@ -193,6 +193,8 @@ class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
clean_caption (bool, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
"""
|
||||
if prompt is not None and negative_prompt is not None:
|
||||
if type(prompt) is not type(negative_prompt):
|
||||
@@ -568,13 +570,13 @@ class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
||||
timesteps are used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
|
||||
@@ -184,14 +184,13 @@ class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
self.final_offload_hook = None
|
||||
|
||||
@torch.no_grad()
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
do_classifier_free_guidance=True,
|
||||
num_images_per_prompt=1,
|
||||
device=None,
|
||||
negative_prompt=None,
|
||||
prompt: Union[str, List[str]],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
clean_caption: bool = False,
|
||||
@@ -200,14 +199,14 @@ class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
@@ -219,6 +218,8 @@ class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
clean_caption (bool, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
"""
|
||||
if prompt is not None and negative_prompt is not None:
|
||||
if type(prompt) is not type(negative_prompt):
|
||||
@@ -686,19 +687,19 @@ class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
||||
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
||||
process.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
strength (`float`, *optional*, defaults to 0.7):
|
||||
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
|
||||
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
|
||||
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
|
||||
be maximum and the denoising process will run for the full number of iterations specified in
|
||||
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
num_inference_steps (`int`, *optional*, defaults to 80):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
||||
timesteps are used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
guidance_scale (`float`, *optional*, defaults to 10.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
|
||||
@@ -338,11 +338,11 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
do_classifier_free_guidance=True,
|
||||
num_images_per_prompt=1,
|
||||
device=None,
|
||||
negative_prompt=None,
|
||||
prompt: Union[str, List[str]],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
clean_caption: bool = False,
|
||||
@@ -351,14 +351,14 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
@@ -370,6 +370,8 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
clean_caption (bool, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
"""
|
||||
if prompt is not None and negative_prompt is not None:
|
||||
if type(prompt) is not type(negative_prompt):
|
||||
@@ -784,7 +786,7 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
||||
timesteps are used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
|
||||
@@ -190,11 +190,11 @@ class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
do_classifier_free_guidance=True,
|
||||
num_images_per_prompt=1,
|
||||
device=None,
|
||||
negative_prompt=None,
|
||||
prompt: Union[str, List[str]],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
clean_caption: bool = False,
|
||||
@@ -203,14 +203,14 @@ class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
@@ -222,6 +222,8 @@ class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
clean_caption (bool, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
"""
|
||||
if prompt is not None and negative_prompt is not None:
|
||||
if type(prompt) is not type(negative_prompt):
|
||||
@@ -786,7 +788,7 @@ class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
|
||||
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
|
||||
instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
strength (`float`, *optional*, defaults to 1.0):
|
||||
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
|
||||
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
|
||||
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
|
||||
@@ -798,7 +800,7 @@ class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
||||
timesteps are used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
|
||||
@@ -340,11 +340,11 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
do_classifier_free_guidance=True,
|
||||
num_images_per_prompt=1,
|
||||
device=None,
|
||||
negative_prompt=None,
|
||||
prompt: Union[str, List[str]],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
clean_caption: bool = False,
|
||||
@@ -353,14 +353,14 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
@@ -372,6 +372,8 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
clean_caption (bool, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
"""
|
||||
if prompt is not None and negative_prompt is not None:
|
||||
if type(prompt) is not type(negative_prompt):
|
||||
@@ -874,13 +876,13 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
||||
timesteps are used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
|
||||
@@ -296,11 +296,11 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
do_classifier_free_guidance=True,
|
||||
num_images_per_prompt=1,
|
||||
device=None,
|
||||
negative_prompt=None,
|
||||
prompt: Union[str, List[str]],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
clean_caption: bool = False,
|
||||
@@ -309,14 +309,14 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
@@ -328,6 +328,8 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
clean_caption (bool, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
"""
|
||||
if prompt is not None and negative_prompt is not None:
|
||||
if type(prompt) is not type(negative_prompt):
|
||||
@@ -637,19 +639,19 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
height (`int`, *optional*, defaults to None):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
width (`int`, *optional*, defaults to None):
|
||||
The width in pixels of the generated image.
|
||||
image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`):
|
||||
The image to be upscaled.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
timesteps (`List[int]`, *optional*, defaults to None):
|
||||
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
||||
timesteps are used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
|
||||
@@ -305,13 +305,22 @@ def maybe_raise_or_warn(
|
||||
)
|
||||
|
||||
|
||||
def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module):
|
||||
def get_class_obj_and_candidates(
|
||||
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
|
||||
):
|
||||
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
|
||||
if is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
||||
elif library_name not in LOADABLE_CLASSES.keys():
|
||||
# load custom component
|
||||
component_folder = os.path.join(cache_dir, component_name)
|
||||
class_obj = get_class_from_dynamic_module(
|
||||
component_folder, module_file=library_name + ".py", class_name=class_name
|
||||
)
|
||||
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
||||
else:
|
||||
# else we just import it from the library.
|
||||
library = importlib.import_module(library_name)
|
||||
@@ -323,7 +332,15 @@ def get_class_obj_and_candidates(library_name, class_name, importable_classes, p
|
||||
|
||||
|
||||
def _get_pipeline_class(
|
||||
class_obj, config, load_connected_pipeline=False, custom_pipeline=None, cache_dir=None, revision=None
|
||||
class_obj,
|
||||
config,
|
||||
load_connected_pipeline=False,
|
||||
custom_pipeline=None,
|
||||
hub_repo_id=None,
|
||||
hub_revision=None,
|
||||
class_name=None,
|
||||
cache_dir=None,
|
||||
revision=None,
|
||||
):
|
||||
if custom_pipeline is not None:
|
||||
if custom_pipeline.endswith(".py"):
|
||||
@@ -331,11 +348,19 @@ def _get_pipeline_class(
|
||||
# decompose into folder & file
|
||||
file_name = path.name
|
||||
custom_pipeline = path.parent.absolute()
|
||||
elif hub_repo_id is not None:
|
||||
file_name = f"{custom_pipeline}.py"
|
||||
custom_pipeline = hub_repo_id
|
||||
else:
|
||||
file_name = CUSTOM_PIPELINE_FILE_NAME
|
||||
|
||||
return get_class_from_dynamic_module(
|
||||
custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision
|
||||
custom_pipeline,
|
||||
module_file=file_name,
|
||||
class_name=class_name,
|
||||
hub_repo_id=hub_repo_id,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision if hub_revision is None else hub_revision,
|
||||
)
|
||||
|
||||
if class_obj != DiffusionPipeline:
|
||||
@@ -383,11 +408,18 @@ def load_sub_model(
|
||||
variant: str,
|
||||
low_cpu_mem_usage: bool,
|
||||
cached_folder: Union[str, os.PathLike],
|
||||
revision: str = None,
|
||||
):
|
||||
"""Helper method to load the module `name` from `library_name` and `class_name`"""
|
||||
# retrieve class candidates
|
||||
class_obj, class_candidates = get_class_obj_and_candidates(
|
||||
library_name, class_name, importable_classes, pipelines, is_pipeline_module
|
||||
library_name,
|
||||
class_name,
|
||||
importable_classes,
|
||||
pipelines,
|
||||
is_pipeline_module,
|
||||
component_name=name,
|
||||
cache_dir=cached_folder,
|
||||
)
|
||||
|
||||
load_method_name = None
|
||||
@@ -1080,11 +1112,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
# 3. Load the pipeline class, if using custom module then load it from the hub
|
||||
# if we load from explicit class, let's use it
|
||||
custom_class_name = None
|
||||
if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
|
||||
custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
|
||||
elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
|
||||
os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
|
||||
):
|
||||
custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
|
||||
custom_class_name = config_dict["_class_name"][1]
|
||||
|
||||
pipeline_class = _get_pipeline_class(
|
||||
cls,
|
||||
config_dict,
|
||||
load_connected_pipeline=load_connected_pipeline,
|
||||
custom_pipeline=custom_pipeline,
|
||||
class_name=custom_class_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=custom_revision,
|
||||
)
|
||||
@@ -1223,6 +1265,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
variant=variant,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cached_folder=cached_folder,
|
||||
revision=revision,
|
||||
)
|
||||
logger.info(
|
||||
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
|
||||
@@ -1542,6 +1585,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
|
||||
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
|
||||
with `.onnx` and `.pb`.
|
||||
trust_remote_code (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
|
||||
option should only be set to `True` for repositories you trust and in which you have read the code, as
|
||||
it will execute code present on the Hub on your local machine.
|
||||
|
||||
Returns:
|
||||
`os.PathLike`:
|
||||
@@ -1569,6 +1616,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
use_onnx = kwargs.pop("use_onnx", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", False)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
@@ -1604,12 +1652,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
|
||||
ignore_filenames = config_dict.pop("_ignore_files", [])
|
||||
|
||||
# retrieve all folder_names that contain relevant files
|
||||
folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
|
||||
|
||||
# optionally create a custom component <> custom file mapping
|
||||
custom_components = {}
|
||||
for component in folder_names:
|
||||
if config_dict[component][0] not in LOADABLE_CLASSES.keys():
|
||||
custom_components[component] = config_dict[component][0]
|
||||
|
||||
filenames = {sibling.rfilename for sibling in info.siblings}
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
|
||||
@@ -1636,12 +1689,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
|
||||
|
||||
custom_class_name = None
|
||||
if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
|
||||
custom_pipeline = config_dict["_class_name"][0]
|
||||
custom_class_name = config_dict["_class_name"][1]
|
||||
|
||||
# all filenames compatible with variant will be added
|
||||
allow_patterns = list(model_filenames)
|
||||
|
||||
# allow all patterns from non-model folders
|
||||
# this enables downloading schedulers, tokenizers, ...
|
||||
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
|
||||
# add custom component files
|
||||
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
|
||||
# add custom pipeline file
|
||||
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
|
||||
# also allow downloading config.json files with the model
|
||||
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
|
||||
|
||||
@@ -1652,12 +1714,32 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
CUSTOM_PIPELINE_FILE_NAME,
|
||||
]
|
||||
|
||||
load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
|
||||
load_components_from_hub = len(custom_components) > 0
|
||||
|
||||
if load_pipe_from_hub and not trust_remote_code:
|
||||
raise ValueError(
|
||||
f"The repository for {pretrained_model_name} contains custom code in {custom_pipeline}.py which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name}/blob/main/{custom_pipeline}.py.\n"
|
||||
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
||||
)
|
||||
|
||||
if load_components_from_hub and not trust_remote_code:
|
||||
raise ValueError(
|
||||
f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
|
||||
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
||||
)
|
||||
|
||||
# retrieve passed components that should not be downloaded
|
||||
pipeline_class = _get_pipeline_class(
|
||||
cls,
|
||||
config_dict,
|
||||
load_connected_pipeline=load_connected_pipeline,
|
||||
custom_pipeline=custom_pipeline,
|
||||
hub_repo_id=pretrained_model_name if load_pipe_from_hub else None,
|
||||
hub_revision=revision,
|
||||
class_name=custom_class_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=custom_revision,
|
||||
)
|
||||
@@ -1754,7 +1836,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
# retrieve pipeline class from local file
|
||||
cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
|
||||
cls_name = cls_name[4:] if cls_name.startswith("Flax") else cls_name
|
||||
cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
|
||||
|
||||
pipeline_class = getattr(diffusers, cls_name, None)
|
||||
|
||||
|
||||
@@ -911,7 +911,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
n_coarse_samples=64,
|
||||
n_fine_samples=128,
|
||||
):
|
||||
# project the the paramters from the generated latents
|
||||
# project the parameters from the generated latents
|
||||
projected_params = self.params_proj(latents)
|
||||
|
||||
# update the mlp layers of the renderer
|
||||
@@ -955,7 +955,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
|
||||
query_batch_size: int = 4096,
|
||||
texture_channels: Tuple = ("R", "G", "B"),
|
||||
):
|
||||
# 1. project the the paramters from the generated latents
|
||||
# 1. project the parameters from the generated latents
|
||||
projected_params = self.params_proj(latents)
|
||||
|
||||
# 2. update the mlp layers of the renderer
|
||||
|
||||
@@ -324,7 +324,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
if "disable_self_attentions" in unet_params:
|
||||
config["only_cross_attention"] = unet_params.disable_self_attentions
|
||||
|
||||
if "num_classes" in unet_params and type(unet_params.num_classes) == int:
|
||||
if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
|
||||
config["num_class_embeds"] = unet_params.num_classes
|
||||
|
||||
if controlnet:
|
||||
@@ -787,7 +787,12 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
|
||||
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
|
||||
if text_encoder is None:
|
||||
config_name = "openai/clip-vit-large-patch14"
|
||||
config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
|
||||
try:
|
||||
config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'."
|
||||
)
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
@@ -922,7 +927,12 @@ def convert_open_clip_checkpoint(
|
||||
# text_model = CLIPTextModelWithProjection.from_pretrained(
|
||||
# "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
|
||||
# )
|
||||
config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only)
|
||||
try:
|
||||
config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'."
|
||||
)
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
@@ -1464,11 +1474,19 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
config_name = "stabilityai/stable-diffusion-2"
|
||||
config_kwargs = {"subfolder": "text_encoder"}
|
||||
|
||||
text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only
|
||||
text_model = convert_open_clip_checkpoint(
|
||||
checkpoint, config_name, local_files_only=local_files_only, **config_kwargs
|
||||
)
|
||||
|
||||
try:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'stabilityai/stable-diffusion-2'."
|
||||
)
|
||||
|
||||
if stable_unclip is None:
|
||||
if controlnet:
|
||||
pipe = pipeline_class(
|
||||
@@ -1546,9 +1564,14 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
karlo_model, subfolder="prior", local_files_only=local_files_only
|
||||
)
|
||||
|
||||
prior_tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", local_files_only=local_files_only
|
||||
)
|
||||
try:
|
||||
prior_tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", local_files_only=local_files_only
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
|
||||
)
|
||||
prior_text_model = CLIPTextModelWithProjection.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", local_files_only=local_files_only
|
||||
)
|
||||
@@ -1581,10 +1604,22 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}")
|
||||
elif model_type == "PaintByExample":
|
||||
vision_model = convert_paint_by_example_checkpoint(checkpoint)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
|
||||
)
|
||||
try:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", local_files_only=local_files_only
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
|
||||
)
|
||||
try:
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'."
|
||||
)
|
||||
pipe = PaintByExamplePipeline(
|
||||
vae=vae,
|
||||
image_encoder=vision_model,
|
||||
@@ -1597,11 +1632,16 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
text_model = convert_ldm_clip_checkpoint(
|
||||
checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
|
||||
)
|
||||
tokenizer = (
|
||||
CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
|
||||
if tokenizer is None
|
||||
else tokenizer
|
||||
)
|
||||
try:
|
||||
tokenizer = (
|
||||
CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
|
||||
if tokenizer is None
|
||||
else tokenizer
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
|
||||
)
|
||||
|
||||
if load_safety_checker:
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
@@ -1637,18 +1677,33 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
)
|
||||
elif model_type in ["SDXL", "SDXL-Refiner"]:
|
||||
if model_type == "SDXL":
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", local_files_only=local_files_only
|
||||
)
|
||||
try:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", local_files_only=local_files_only
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
|
||||
)
|
||||
text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained(
|
||||
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
|
||||
)
|
||||
try:
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained(
|
||||
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
|
||||
)
|
||||
|
||||
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
config_kwargs = {"projection_dim": 1280}
|
||||
text_encoder_2 = convert_open_clip_checkpoint(
|
||||
checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
|
||||
checkpoint,
|
||||
config_name,
|
||||
prefix="conditioner.embedders.1.model.",
|
||||
has_projection=True,
|
||||
local_files_only=local_files_only,
|
||||
**config_kwargs,
|
||||
)
|
||||
|
||||
if is_accelerate_available(): # SBM Now move model to cpu.
|
||||
@@ -1682,14 +1737,23 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
else:
|
||||
tokenizer = None
|
||||
text_encoder = None
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained(
|
||||
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
|
||||
)
|
||||
|
||||
try:
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained(
|
||||
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
|
||||
)
|
||||
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
config_kwargs = {"projection_dim": 1280}
|
||||
text_encoder_2 = convert_open_clip_checkpoint(
|
||||
checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs
|
||||
checkpoint,
|
||||
config_name,
|
||||
prefix="conditioner.embedders.0.model.",
|
||||
has_projection=True,
|
||||
local_files_only=local_files_only,
|
||||
**config_kwargs,
|
||||
)
|
||||
|
||||
if is_accelerate_available(): # SBM Now move model to cpu.
|
||||
|
||||
@@ -438,7 +438,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -434,7 +434,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
+1
-1
@@ -469,7 +469,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -343,7 +343,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -614,7 +614,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -411,7 +411,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
+1
-1
@@ -436,7 +436,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline):
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -435,7 +435,7 @@ class StableDiffusionImg2ImgPipeline(
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -505,7 +505,7 @@ class StableDiffusionInpaintPipeline(
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
+1
-1
@@ -427,7 +427,7 @@ class StableDiffusionInpaintPipelineLegacy(
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -341,7 +341,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -405,7 +405,7 @@ class StableDiffusionLDM3DPipeline(
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -374,7 +374,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -358,7 +358,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -389,7 +389,7 @@ class StableDiffusionParadigmsPipeline(
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -579,7 +579,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -381,7 +381,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -372,7 +372,7 @@ class StableDiffusionUpscalePipeline(
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -479,7 +479,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -433,7 +433,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -140,6 +140,7 @@ class StableDiffusionXLPipeline(
|
||||
watermarker will be used.
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -167,6 +168,7 @@ class StableDiffusionXLPipeline(
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
self.default_sample_size = self.unet.config.sample_size
|
||||
|
||||
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
||||
@@ -275,12 +277,17 @@ class StableDiffusionXLPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
if self.text_encoder is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -396,7 +403,11 @@ class StableDiffusionXLPipeline(
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
if self.text_encoder_2 is not None:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -405,7 +416,12 @@ class StableDiffusionXLPipeline(
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
@@ -417,10 +433,15 @@ class StableDiffusionXLPipeline(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
@@ -533,11 +554,13 @@ class StableDiffusionXLPipeline(
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
||||
):
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
@@ -843,8 +866,17 @@ class StableDiffusionXLPipeline(
|
||||
|
||||
# 7. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self._get_add_time_ids(
|
||||
@@ -852,6 +884,7 @@ class StableDiffusionXLPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
|
||||
+40
-15
@@ -143,8 +143,7 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
watermarker will be used.
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
|
||||
_optional_components = ["tokenizer", "text_encoder"]
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -282,12 +281,17 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
if self.text_encoder is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -403,7 +407,11 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
if self.text_encoder_2 is not None:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -412,7 +420,12 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
@@ -424,10 +437,15 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
@@ -618,6 +636,7 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype,
|
||||
text_encoder_projection_dim=None,
|
||||
):
|
||||
if self.config.requires_aesthetics_score:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
|
||||
@@ -629,7 +648,7 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
@@ -983,6 +1002,11 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
negative_target_size = target_size
|
||||
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
@@ -993,6 +1017,7 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
|
||||
+40
-14
@@ -290,7 +290,7 @@ class StableDiffusionXLInpaintPipeline(
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
|
||||
_optional_components = ["tokenizer", "text_encoder"]
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -431,12 +431,17 @@ class StableDiffusionXLInpaintPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
if self.text_encoder is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -552,7 +557,11 @@ class StableDiffusionXLInpaintPipeline(
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
if self.text_encoder_2 is not None:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -561,7 +570,12 @@ class StableDiffusionXLInpaintPipeline(
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
@@ -573,10 +587,15 @@ class StableDiffusionXLInpaintPipeline(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
@@ -836,6 +855,7 @@ class StableDiffusionXLInpaintPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype,
|
||||
text_encoder_projection_dim=None,
|
||||
):
|
||||
if self.config.requires_aesthetics_score:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
|
||||
@@ -847,7 +867,7 @@ class StableDiffusionXLInpaintPipeline(
|
||||
add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
@@ -1289,6 +1309,11 @@ class StableDiffusionXLInpaintPipeline(
|
||||
negative_target_size = target_size
|
||||
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
@@ -1299,6 +1324,7 @@ class StableDiffusionXLInpaintPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
|
||||
+31
-7
@@ -31,11 +31,13 @@ from ...models.attention_processor import (
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
is_invisible_watermark_available,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
@@ -150,6 +152,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
watermarker will be used.
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -280,8 +283,17 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
|
||||
if self.text_encoder is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -390,7 +402,8 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
prompt_embeds_dtype = self.text_encoder_2.dtype if self.text_encoder_2 is not None else self.unet.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -399,7 +412,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
@@ -552,11 +565,13 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
return image_latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
|
||||
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
||||
):
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
@@ -871,8 +886,17 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
|
||||
# 10. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
|
||||
@@ -429,7 +429,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
@@ -813,7 +813,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
down_block_additional_residuals=[state.clone() for state in adapter_state],
|
||||
down_intrablock_additional_residuals=[state.clone() for state in adapter_state],
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
|
||||
@@ -160,6 +160,7 @@ class StableDiffusionXLAdapterPipeline(
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -290,12 +291,17 @@ class StableDiffusionXLAdapterPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
if self.text_encoder is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -411,7 +417,11 @@ class StableDiffusionXLAdapterPipeline(
|
||||
|
||||
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
if self.text_encoder_2 is not None:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -420,7 +430,12 @@ class StableDiffusionXLAdapterPipeline(
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
||||
else:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
@@ -432,10 +447,15 @@ class StableDiffusionXLAdapterPipeline(
|
||||
bs_embed * num_images_per_prompt, -1
|
||||
)
|
||||
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder_2)
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
@@ -550,11 +570,13 @@ class StableDiffusionXLAdapterPipeline(
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
|
||||
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
||||
):
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
@@ -928,8 +950,17 @@ class StableDiffusionXLAdapterPipeline(
|
||||
adapter_state[k] = torch.cat([v] * 2, dim=0)
|
||||
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self._get_add_time_ids(
|
||||
@@ -937,6 +968,7 @@ class StableDiffusionXLAdapterPipeline(
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
@@ -975,9 +1007,9 @@ class StableDiffusionXLAdapterPipeline(
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
|
||||
if i < int(num_inference_steps * adapter_conditioning_factor):
|
||||
down_block_additional_residuals = [state.clone() for state in adapter_state]
|
||||
down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
|
||||
else:
|
||||
down_block_additional_residuals = None
|
||||
down_intrablock_additional_residuals = None
|
||||
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
@@ -986,7 +1018,7 @@ class StableDiffusionXLAdapterPipeline(
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
down_block_additional_residuals=down_block_additional_residuals,
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
|
||||
@@ -361,7 +361,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
+1
-1
@@ -423,7 +423,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
|
||||
prefix_length (`int`):
|
||||
Max number of prefix tokens that will be supplied to the model.
|
||||
prefix_inner_dim (`int`):
|
||||
The hidden size of the the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the
|
||||
The hidden size of the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the
|
||||
CLIP text encoder.
|
||||
prefix_hidden_dim (`int`, *optional*):
|
||||
Hidden dim of the MLP if we encode the prefix.
|
||||
|
||||
@@ -556,7 +556,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers.utils import deprecate
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models import ModelMixin
|
||||
from ...models.activations import get_activation
|
||||
@@ -279,7 +281,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
|
||||
The tuple of downsample blocks to use.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`):
|
||||
Block type for middle of UNet, it can be either `UNetMidBlockFlatCrossAttn` or
|
||||
Block type for middle of UNet, it can be one of `UNetMidBlockFlatCrossAttn`, `UNetMidBlockFlat`, or
|
||||
`UNetMidBlockFlatSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat")`):
|
||||
The tuple of upsample blocks to use.
|
||||
@@ -298,10 +300,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`].
|
||||
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
|
||||
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
@@ -335,9 +342,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
||||
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of `cond_proj` layer in the timestep embedding.
|
||||
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
|
||||
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
|
||||
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
||||
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
|
||||
*optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
|
||||
*optional*): The dimension of the `class_labels` input when
|
||||
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
|
||||
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
||||
embeddings with the class embeddings.
|
||||
@@ -382,7 +389,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
||||
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
@@ -473,6 +481,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`:"
|
||||
f" {layers_per_block}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
|
||||
for layer_number_per_block in transformer_layers_per_block:
|
||||
if isinstance(layer_number_per_block, list):
|
||||
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
|
||||
|
||||
# input
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
@@ -708,6 +720,19 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
only_cross_attention=mid_block_only_cross_attention,
|
||||
cross_attention_norm=cross_attention_norm,
|
||||
)
|
||||
elif mid_block_type == "UNetMidBlockFlat":
|
||||
self.mid_block = UNetMidBlockFlat(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=blocks_time_embed_dim,
|
||||
dropout=dropout,
|
||||
num_layers=0,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
add_attention=False,
|
||||
)
|
||||
elif mid_block_type is None:
|
||||
self.mid_block = None
|
||||
else:
|
||||
@@ -721,7 +746,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
||||
reversed_layers_per_block = list(reversed(layers_per_block))
|
||||
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
||||
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
||||
reversed_transformer_layers_per_block = (
|
||||
list(reversed(transformer_layers_per_block))
|
||||
if reverse_transformer_layers_per_block is None
|
||||
else reverse_transformer_layers_per_block
|
||||
)
|
||||
only_cross_attention = list(reversed(only_cross_attention))
|
||||
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
@@ -987,6 +1016,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
||||
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
@@ -1031,6 +1061,13 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
added_cond_kwargs: (`dict`, *optional*):
|
||||
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
||||
are passed along to the UNet blocks.
|
||||
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
||||
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
|
||||
example from ControlNet side model(s)
|
||||
mid_block_additional_residual (`torch.Tensor`, *optional*):
|
||||
additional residual to be added to UNet mid block output, for example from ControlNet side model
|
||||
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
||||
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
@@ -1216,15 +1253,31 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
scale_lora_layers(self, lora_scale)
|
||||
|
||||
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
||||
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
|
||||
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
||||
is_adapter = down_intrablock_additional_residuals is not None
|
||||
# maintain backward compatibility for legacy usage, where
|
||||
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
|
||||
# but can only use one or the other
|
||||
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
|
||||
deprecate(
|
||||
"T2I should not use down_block_additional_residuals",
|
||||
"1.3.0",
|
||||
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated "
|
||||
" and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only"
|
||||
" be used for ControlNet. Please make sure use"
|
||||
" `down_intrablock_additional_residuals` instead. ",
|
||||
standard_warn=False,
|
||||
)
|
||||
down_intrablock_additional_residuals = down_block_additional_residuals
|
||||
is_adapter = True
|
||||
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
# For t2i-adapter CrossAttnDownBlockFlat
|
||||
additional_residuals = {}
|
||||
if is_adapter and len(down_block_additional_residuals) > 0:
|
||||
additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
|
||||
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
|
||||
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
@@ -1237,9 +1290,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
|
||||
|
||||
if is_adapter and len(down_block_additional_residuals) > 0:
|
||||
sample += down_block_additional_residuals.pop(0)
|
||||
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||
sample += down_intrablock_additional_residuals.pop(0)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
@@ -1256,21 +1308,25 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# To support T2I-Adapter-XL
|
||||
if (
|
||||
is_adapter
|
||||
and len(down_block_additional_residuals) > 0
|
||||
and sample.shape == down_block_additional_residuals[0].shape
|
||||
and len(down_intrablock_additional_residuals) > 0
|
||||
and sample.shape == down_intrablock_additional_residuals[0].shape
|
||||
):
|
||||
sample += down_block_additional_residuals.pop(0)
|
||||
sample += down_intrablock_additional_residuals.pop(0)
|
||||
|
||||
if is_controlnet:
|
||||
sample = sample + mid_block_additional_residual
|
||||
@@ -1315,7 +1371,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self)
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
@@ -1532,7 +1588,7 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
@@ -1555,6 +1611,8 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.num_attention_heads = num_attention_heads
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
@@ -1578,7 +1636,7 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
num_attention_heads,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=transformer_layers_per_block,
|
||||
num_layers=transformer_layers_per_block[i],
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
@@ -1798,7 +1856,7 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
@@ -1821,6 +1879,9 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
self.has_cross_attention = True
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
@@ -1845,7 +1906,7 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
num_attention_heads,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=transformer_layers_per_block,
|
||||
num_layers=transformer_layers_per_block[i],
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
@@ -1958,6 +2019,133 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat
|
||||
class UNetMidBlockFlat(nn.Module):
|
||||
"""
|
||||
A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): The number of input channels.
|
||||
temb_channels (`int`): The number of temporal embedding channels.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
||||
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
|
||||
The type of normalization to apply to the time embeddings. This can help to improve the performance of the
|
||||
model on tasks with long-range temporal dependencies.
|
||||
resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
|
||||
resnet_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use in the group normalization layers of the resnet blocks.
|
||||
attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
|
||||
resnet_pre_norm (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use pre-normalization for the resnet blocks.
|
||||
add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
|
||||
attention_head_dim (`int`, *optional*, defaults to 1):
|
||||
Dimension of a single attention head. The number of attention heads is determined based on this value and
|
||||
the number of input channels.
|
||||
output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
||||
in_channels, height, width)`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default", # default, spatial
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
attn_groups: Optional[int] = None,
|
||||
resnet_pre_norm: bool = True,
|
||||
add_attention: bool = True,
|
||||
attention_head_dim=1,
|
||||
output_scale_factor=1.0,
|
||||
):
|
||||
super().__init__()
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
self.add_attention = add_attention
|
||||
|
||||
if attn_groups is None:
|
||||
attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlockFlat(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
|
||||
if attention_head_dim is None:
|
||||
logger.warn(
|
||||
"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to"
|
||||
f" `in_channels`: {in_channels}."
|
||||
)
|
||||
attention_head_dim = in_channels
|
||||
|
||||
for _ in range(num_layers):
|
||||
if self.add_attention:
|
||||
attentions.append(
|
||||
Attention(
|
||||
in_channels,
|
||||
heads=in_channels // attention_head_dim,
|
||||
dim_head=attention_head_dim,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
norm_num_groups=attn_groups,
|
||||
spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
|
||||
residual_connection=True,
|
||||
bias=True,
|
||||
upcast_softmax=True,
|
||||
_from_deprecated_attn_block=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
attentions.append(None)
|
||||
|
||||
resnets.append(
|
||||
ResnetBlockFlat(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states, temb=temb)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
|
||||
class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
def __init__(
|
||||
@@ -1966,7 +2154,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
@@ -1986,6 +2174,10 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
self.num_attention_heads = num_attention_heads
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
# support for variable transformer layers per block
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlockFlat(
|
||||
@@ -2003,14 +2195,14 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
]
|
||||
attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
for i in range(num_layers):
|
||||
if not dual_cross_attention:
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
num_attention_heads,
|
||||
in_channels // num_attention_heads,
|
||||
in_channels=in_channels,
|
||||
num_layers=transformer_layers_per_block,
|
||||
num_layers=transformer_layers_per_block[i],
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user