Compare commits
64 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4b682be902 | |||
| 3d0bb51d53 | |||
| 4b72aae0cd | |||
| 33bbe58ea7 | |||
| a1cb106459 | |||
| 5dd8e04d4b | |||
| 165af7edd3 | |||
| 6c5f0de713 | |||
| e64fdcf2ce | |||
| ec64f371b1 | |||
| cd6e1f1171 | |||
| 6f2b310a17 | |||
| e3cd6cae50 | |||
| e5ee05da76 | |||
| e6ff752840 | |||
| 3f9c746fb2 | |||
| 1f22c98820 | |||
| b4226bd6a7 | |||
| 46fac824be | |||
| b33b64f595 | |||
| 9d9744075e | |||
| d9a3b69806 | |||
| f7e5954d5e | |||
| 8e19c073e5 | |||
| f6df16cbb8 | |||
| b24f78349c | |||
| 3ce905c9d0 | |||
| f539497ab4 | |||
| 39dfb7abbd | |||
| 196835695e | |||
| 0d4dfbbd0a | |||
| ada3bb941b | |||
| b5814c5555 | |||
| 9940573618 | |||
| 59433ca1ae | |||
| 534f5d54fa | |||
| 40aa47b998 | |||
| 1bc0d37ffe | |||
| eb942b866a | |||
| 687bc27727 | |||
| 6246c70d21 | |||
| 577b8a2783 | |||
| 13f0c8b219 | |||
| fa1bdce3d4 | |||
| ca6cdc77a9 | |||
| f4977abcd8 | |||
| df8559a7f9 | |||
| 8f206a5873 | |||
| 8da360aa12 | |||
| 869bad3e52 | |||
| 01ee0978cc | |||
| 56b68459f5 | |||
| 2ca264244b | |||
| b9e1c30d0e | |||
| 03cd62520f | |||
| 001b14023e | |||
| f55873b783 | |||
| ccb93dcad1 | |||
| ec953047bc | |||
| 9a2600ede9 | |||
| 5f150c4cef | |||
| 66f8bd6869 | |||
| 64a8cd627a | |||
| 5d3923b670 |
@@ -1,6 +1,7 @@
|
||||
name: Benchmarking tests
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "30 1 1,15 * *" # every 2 weeks on the 1st and the 15th of every month at 1:30 AM
|
||||
|
||||
|
||||
@@ -1,22 +1,58 @@
|
||||
name: Build Docker images (nightly)
|
||||
name: Test, build, and push Docker images
|
||||
|
||||
on:
|
||||
pull_request: # During PRs, we just check if the changes Dockerfiles can be successfully built
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "docker/**"
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 0 * * *" # every day at midnight
|
||||
|
||||
concurrency:
|
||||
group: docker-image-builds
|
||||
cancel-in-progress: false
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
REGISTRY: diffusers
|
||||
CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
|
||||
|
||||
jobs:
|
||||
build-docker-images:
|
||||
test-build-docker-images:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event_name == 'pull_request'
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Find Changed Dockerfiles
|
||||
id: file_changes
|
||||
uses: jitterbit/get-changed-files@v1
|
||||
with:
|
||||
format: 'space-delimited'
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build Changed Docker Images
|
||||
run: |
|
||||
CHANGED_FILES="${{ steps.file_changes.outputs.all }}"
|
||||
for FILE in $CHANGED_FILES; do
|
||||
if [[ "$FILE" == docker/*Dockerfile ]]; then
|
||||
DOCKER_PATH="${FILE%/Dockerfile}"
|
||||
DOCKER_TAG=$(basename "$DOCKER_PATH")
|
||||
echo "Building Docker image for $DOCKER_TAG"
|
||||
docker build -t "$DOCKER_TAG" "$DOCKER_PATH"
|
||||
fi
|
||||
done
|
||||
if: steps.file_changes.outputs.all != ''
|
||||
|
||||
build-and-push-docker-images:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event_name != 'pull_request'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
@@ -141,6 +141,7 @@ class LCMLoRATextToImageBenchmark(TextToImageBenchmark):
|
||||
super().__init__(args)
|
||||
self.pipe.load_lora_weights(self.lora_id)
|
||||
self.pipe.fuse_lora()
|
||||
self.pipe.unload_lora_weights()
|
||||
self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
|
||||
|
||||
def get_result_filepath(self, args):
|
||||
@@ -235,6 +236,35 @@ class InpaintingBenchmark(ImageToImageBenchmark):
|
||||
)
|
||||
|
||||
|
||||
class IPAdapterTextToImageBenchmark(TextToImageBenchmark):
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png"
|
||||
image = load_image(url)
|
||||
|
||||
def __init__(self, args):
|
||||
pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16).to("cuda")
|
||||
pipe.load_ip_adapter(
|
||||
args.ip_adapter_id[0],
|
||||
subfolder="models" if "sdxl" not in args.ip_adapter_id[1] else "sdxl_models",
|
||||
weight_name=args.ip_adapter_id[1],
|
||||
)
|
||||
|
||||
if args.run_compile:
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
print("Run torch compile")
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
self.pipe = pipe
|
||||
|
||||
def run_inference(self, pipe, args):
|
||||
_ = pipe(
|
||||
prompt=PROMPT,
|
||||
ip_adapter_image=self.image,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
num_images_per_prompt=args.batch_size,
|
||||
)
|
||||
|
||||
|
||||
class ControlNetBenchmark(TextToImageBenchmark):
|
||||
pipeline_class = StableDiffusionControlNetPipeline
|
||||
aux_network_class = ControlNetModel
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from base_classes import IPAdapterTextToImageBenchmark # noqa: E402
|
||||
|
||||
|
||||
IP_ADAPTER_CKPTS = {
|
||||
"runwayml/stable-diffusion-v1-5": ("h94/IP-Adapter", "ip-adapter_sd15.bin"),
|
||||
"stabilityai/stable-diffusion-xl-base-1.0": ("h94/IP-Adapter", "ip-adapter_sdxl.bin"),
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="runwayml/stable-diffusion-v1-5",
|
||||
choices=list(IP_ADAPTER_CKPTS.keys()),
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
parser.add_argument("--model_cpu_offload", action="store_true")
|
||||
parser.add_argument("--run_compile", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.ip_adapter_id = IP_ADAPTER_CKPTS[args.ckpt]
|
||||
benchmark_pipe = IPAdapterTextToImageBenchmark(args)
|
||||
args.ckpt = f"{args.ckpt} (IP-Adapter)"
|
||||
benchmark_pipe.benchmark(args)
|
||||
@@ -72,7 +72,7 @@ def main():
|
||||
command += " --run_compile"
|
||||
run_command(command.split())
|
||||
|
||||
elif file == "benchmark_sd_inpainting.py":
|
||||
elif file in ["benchmark_sd_inpainting.py", "benchmark_ip_adapters.py"]:
|
||||
sdxl_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
command = f"python {file} --ckpt {sdxl_ckpt}"
|
||||
run_command(command.split())
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
|
||||
FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
@@ -24,9 +24,9 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
torch==2.1.2 \
|
||||
torchvision==0.16.2 \
|
||||
torchaudio==2.1.2 \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
"onnxruntime-gpu>=1.13.1" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu117 && \
|
||||
python3 -m uv pip install --no-cache-dir \
|
||||
|
||||
@@ -40,6 +40,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
transformers matplotlib
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
- local: tutorials/basic_training
|
||||
title: Train a diffusion model
|
||||
- local: tutorials/using_peft_for_inference
|
||||
title: Inference with PEFT
|
||||
title: Load LoRAs for inference
|
||||
- local: tutorials/fast_diffusion
|
||||
title: Accelerate inference of text-to-image diffusion models
|
||||
title: Tutorials
|
||||
@@ -62,6 +62,8 @@
|
||||
title: Textual inversion
|
||||
- local: using-diffusers/ip_adapter
|
||||
title: IP-Adapter
|
||||
- local: using-diffusers/merge_loras
|
||||
title: Merge LoRAs
|
||||
- local: training/distributed_inference
|
||||
title: Distributed inference with multiple GPUs
|
||||
- local: using-diffusers/reusing_seeds
|
||||
@@ -318,6 +320,8 @@
|
||||
title: Semantic Guidance
|
||||
- local: api/pipelines/shap_e
|
||||
title: Shap-E
|
||||
- local: api/pipelines/stable_cascade
|
||||
title: Stable Cascade
|
||||
- sections:
|
||||
- local: api/pipelines/stable_diffusion/overview
|
||||
title: Overview
|
||||
@@ -418,6 +422,8 @@
|
||||
title: ScoreSdeVeScheduler
|
||||
- local: api/schedulers/score_sde_vp
|
||||
title: ScoreSdeVpScheduler
|
||||
- local: api/schedulers/tcd
|
||||
title: TCDScheduler
|
||||
- local: api/schedulers/unipc
|
||||
title: UniPCMultistepScheduler
|
||||
- local: api/schedulers/vq_diffusion
|
||||
|
||||
@@ -23,3 +23,7 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading]
|
||||
## IPAdapterMixin
|
||||
|
||||
[[autodoc]] loaders.ip_adapter.IPAdapterMixin
|
||||
|
||||
## IPAdapterMaskProcessor
|
||||
|
||||
[[autodoc]] image_processor.IPAdapterMaskProcessor
|
||||
@@ -0,0 +1,88 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Stable Cascade
|
||||
|
||||
This model is built upon the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture and its main
|
||||
difference to other models like Stable Diffusion is that it is working at a much smaller latent space. Why is this
|
||||
important? The smaller the latent space, the **faster** you can run inference and the **cheaper** the training becomes.
|
||||
How small is the latent space? Stable Diffusion uses a compression factor of 8, resulting in a 1024x1024 image being
|
||||
encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that it is possible to encode a
|
||||
1024x1024 image to 24x24, while maintaining crisp reconstructions. The text-conditional model is then trained in the
|
||||
highly compressed latent space. Previous versions of this architecture, achieved a 16x cost reduction over Stable
|
||||
Diffusion 1.5.
|
||||
|
||||
Therefore, this kind of model is well suited for usages where efficiency is important. Furthermore, all known extensions
|
||||
like finetuning, LoRA, ControlNet, IP-Adapter, LCM etc. are possible with this method as well.
|
||||
|
||||
The original codebase can be found at [Stability-AI/StableCascade](https://github.com/Stability-AI/StableCascade).
|
||||
|
||||
## Model Overview
|
||||
Stable Cascade consists of three models: Stage A, Stage B and Stage C, representing a cascade to generate images,
|
||||
hence the name "Stable Cascade".
|
||||
|
||||
Stage A & B are used to compress images, similar to what the job of the VAE is in Stable Diffusion.
|
||||
However, with this setup, a much higher compression of images can be achieved. While the Stable Diffusion models use a
|
||||
spatial compression factor of 8, encoding an image with resolution of 1024 x 1024 to 128 x 128, Stable Cascade achieves
|
||||
a compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24, while being able to accurately decode the
|
||||
image. This comes with the great benefit of cheaper training and inference. Furthermore, Stage C is responsible
|
||||
for generating the small 24 x 24 latents given a text prompt.
|
||||
|
||||
## Uses
|
||||
|
||||
### Direct Use
|
||||
|
||||
The model is intended for research purposes for now. Possible research areas and tasks include
|
||||
|
||||
- Research on generative models.
|
||||
- Safe deployment of models which have the potential to generate harmful content.
|
||||
- Probing and understanding the limitations and biases of generative models.
|
||||
- Generation of artworks and use in design and other artistic processes.
|
||||
- Applications in educational or creative tools.
|
||||
|
||||
Excluded uses are described below.
|
||||
|
||||
### Out-of-Scope Use
|
||||
|
||||
The model was not trained to be factual or true representations of people or events,
|
||||
and therefore using the model to generate such content is out-of-scope for the abilities of this model.
|
||||
The model should not be used in any way that violates Stability AI's [Acceptable Use Policy](https://stability.ai/use-policy).
|
||||
|
||||
## Limitations and Bias
|
||||
|
||||
### Limitations
|
||||
- Faces and people in general may not be generated properly.
|
||||
- The autoencoding part of the model is lossy.
|
||||
|
||||
|
||||
## StableCascadeCombinedPipeline
|
||||
|
||||
[[autodoc]] StableCascadeCombinedPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableCascadePriorPipeline
|
||||
|
||||
[[autodoc]] StableCascadePriorPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableCascadePriorPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.stable_cascade.pipeline_stable_cascade_prior.StableCascadePriorPipelineOutput
|
||||
|
||||
## StableCascadeDecoderPipeline
|
||||
|
||||
[[autodoc]] StableCascadeDecoderPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# TCDScheduler
|
||||
|
||||
[Trajectory Consistency Distillation](https://huggingface.co/papers/2402.19159) by Jianbin Zheng, Minghui Hu, Zhongyi Fan, Chaoyue Wang, Changxing Ding, Dacheng Tao and Tat-Jen Cham introduced a Strategic Stochastic Sampling (Algorithm 4) that is capable of generating good samples in a small number of steps. Distinguishing it as an advanced iteration of the multistep scheduler (Algorithm 1) in the [Consistency Models](https://huggingface.co/papers/2303.01469), Strategic Stochastic Sampling specifically tailored for the trajectory consistency function.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Latent Consistency Model (LCM) extends the Consistency Model to the latent space and leverages the guided consistency distillation technique to achieve impressive performance in accelerating text-to-image synthesis. However, we observed that LCM struggles to generate images with both clarity and detailed intricacy. To address this limitation, we initially delve into and elucidate the underlying causes. Our investigation identifies that the primary issue stems from errors in three distinct areas. Consequently, we introduce Trajectory Consistency Distillation (TCD), which encompasses trajectory consistency function and strategic stochastic sampling. The trajectory consistency function diminishes the distillation errors by broadening the scope of the self-consistency boundary condition and endowing the TCD with the ability to accurately trace the entire trajectory of the Probability Flow ODE. Additionally, strategic stochastic sampling is specifically designed to circumvent the accumulated errors inherent in multi-step consistency sampling, which is meticulously tailored to complement the TCD model. Experiments demonstrate that TCD not only significantly enhances image quality at low NFEs but also yields more detailed results compared to the teacher model at high NFEs.*
|
||||
|
||||
The original codebase can be found at [jabir-zheng/TCD](https://github.com/jabir-zheng/TCD).
|
||||
|
||||
## TCDScheduler
|
||||
[[autodoc]] TCDScheduler
|
||||
|
||||
|
||||
## TCDSchedulerOutput
|
||||
[[autodoc]] schedulers.scheduling_tcd.TCDSchedulerOutput
|
||||
|
||||
@@ -77,7 +77,7 @@ accelerate config default
|
||||
|
||||
Or if your environment doesn't support an interactive shell, like a notebook, you can use:
|
||||
|
||||
```bash
|
||||
```py
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
write_basic_config()
|
||||
@@ -170,7 +170,7 @@ Aside from setting up the LoRA layers, the training script is more or less the s
|
||||
|
||||
Once you've made all your changes or you're okay with the default configuration, you're ready to launch the training script! 🚀
|
||||
|
||||
Let's train on the [Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset to generate our yown Pokémon. Set the environment variables `MODEL_NAME` and `DATASET_NAME` to the model and dataset respectively. You should also specify where to save the model in `OUTPUT_DIR`, and the name of the model to save to on the Hub with `HUB_MODEL_ID`. The script creates and saves the following files to your repository:
|
||||
Let's train on the [Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset to generate our own Pokémon. Set the environment variables `MODEL_NAME` and `DATASET_NAME` to the model and dataset respectively. You should also specify where to save the model in `OUTPUT_DIR`, and the name of the model to save to on the Hub with `HUB_MODEL_ID`. The script creates and saves the following files to your repository:
|
||||
|
||||
- saved model checkpoints
|
||||
- `pytorch_lora_weights.safetensors` (the trained LoRA weights)
|
||||
|
||||
@@ -14,19 +14,17 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Load LoRAs for inference
|
||||
|
||||
There are many adapters (with LoRAs being the most common type) trained in different styles to achieve different effects. You can even combine multiple adapters to create new and unique images. With the 🤗 [PEFT](https://huggingface.co/docs/peft/index) integration in 🤗 Diffusers, it is really easy to load and manage adapters for inference. In this guide, you'll learn how to use different adapters with [Stable Diffusion XL (SDXL)](../api/pipelines/stable_diffusion/stable_diffusion_xl) for inference.
|
||||
There are many adapter types (with [LoRAs](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) being the most popular) trained in different styles to achieve different effects. You can even combine multiple adapters to create new and unique images.
|
||||
|
||||
Throughout this guide, you'll use LoRA as the main adapter technique, so we'll use the terms LoRA and adapter interchangeably. You should have some familiarity with LoRA, and if you don't, we welcome you to check out the [LoRA guide](https://huggingface.co/docs/peft/conceptual_guides/lora).
|
||||
In this tutorial, you'll learn how to easily load and manage adapters for inference with the 🤗 [PEFT](https://huggingface.co/docs/peft/index) integration in 🤗 Diffusers. You'll use LoRA as the main adapter technique, so you'll see the terms LoRA and adapter used interchangeably.
|
||||
|
||||
Let's first install all the required libraries.
|
||||
|
||||
```bash
|
||||
!pip install -q transformers accelerate
|
||||
!pip install peft
|
||||
!pip install diffusers
|
||||
!pip install -q transformers accelerate peft diffusers
|
||||
```
|
||||
|
||||
Now, let's load a pipeline with a SDXL checkpoint:
|
||||
Now, load a pipeline with a [Stable Diffusion XL (SDXL)](../api/pipelines/stable_diffusion/stable_diffusion_xl) checkpoint:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
@@ -36,16 +34,13 @@ pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")
|
||||
```
|
||||
|
||||
|
||||
Next, load a LoRA checkpoint with the [`~diffusers.loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] method.
|
||||
|
||||
With the 🤗 PEFT integration, you can assign a specific `adapter_name` to the checkpoint, which let's you easily switch between different LoRA checkpoints. Let's call this adapter `"toy"`.
|
||||
Next, load a [CiroN2022/toy-face](https://huggingface.co/CiroN2022/toy-face) adapter with the [`~diffusers.loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] method. With the 🤗 PEFT integration, you can assign a specific `adapter_name` to the checkpoint, which let's you easily switch between different LoRA checkpoints. Let's call this adapter `"toy"`.
|
||||
|
||||
```python
|
||||
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
```
|
||||
|
||||
And then perform inference:
|
||||
Make sure to include the token `toy_face` in the prompt and then you can perform inference:
|
||||
|
||||
```python
|
||||
prompt = "toy_face of a hacker with a hoodie"
|
||||
@@ -59,17 +54,16 @@ image
|
||||
|
||||

|
||||
|
||||
With the `adapter_name` parameter, it is really easy to use another adapter for inference! Load the [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) adapter that has been fine-tuned to generate pixel art images and call it `"pixel"`.
|
||||
|
||||
With the `adapter_name` parameter, it is really easy to use another adapter for inference! Load the [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) adapter that has been fine-tuned to generate pixel art images, and let's call it `"pixel"`.
|
||||
|
||||
The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter. But you can activate the `"pixel"` adapter with the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method as shown below:
|
||||
The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method:
|
||||
|
||||
```python
|
||||
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipe.set_adapters("pixel")
|
||||
```
|
||||
|
||||
Let's now generate an image with the second adapter and check the result:
|
||||
Make sure you include the token `pixel art` in your prompt to generate a pixel art image:
|
||||
|
||||
```python
|
||||
prompt = "a hacker with a hoodie, pixel art"
|
||||
@@ -81,29 +75,25 @@ image
|
||||
|
||||

|
||||
|
||||
## Combine multiple adapters
|
||||
## Merge adapters
|
||||
|
||||
You can also perform multi-adapter inference where you combine different adapter checkpoints for inference.
|
||||
You can also merge different adapter checkpoints for inference to blend their styles together.
|
||||
|
||||
Once again, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate two LoRA checkpoints and specify the weight for how the checkpoints should be combined.
|
||||
Once again, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.
|
||||
|
||||
```python
|
||||
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
|
||||
```
|
||||
|
||||
Now that we have set these two adapters, let's generate an image from the combined adapters!
|
||||
|
||||
<Tip>
|
||||
|
||||
LoRA checkpoints in the diffusion community are almost always obtained with [DreamBooth](https://huggingface.co/docs/diffusers/main/en/training/dreambooth). DreamBooth training often relies on "trigger" words in the input text prompts in order for the generation results to look as expected. When you combine multiple LoRA checkpoints, it's important to ensure the trigger words for the corresponding LoRA checkpoints are present in the input text prompts.
|
||||
|
||||
</Tip>
|
||||
|
||||
The trigger words for [CiroN2022/toy-face](https://hf.co/CiroN2022/toy-face) and [nerijs/pixel-art-xl](https://hf.co/nerijs/pixel-art-xl) are found in their repositories.
|
||||
|
||||
Remember to use the trigger words for [CiroN2022/toy-face](https://hf.co/CiroN2022/toy-face) and [nerijs/pixel-art-xl](https://hf.co/nerijs/pixel-art-xl) (these are found in their repositories) in the prompt to generate an image.
|
||||
|
||||
```python
|
||||
# Notice how the prompt is constructed.
|
||||
prompt = "toy_face of a hacker with a hoodie, pixel art"
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=30, cross_attention_kwargs={"scale": 1.0}, generator=torch.manual_seed(0)
|
||||
@@ -113,15 +103,16 @@ image
|
||||
|
||||

|
||||
|
||||
Impressive! As you can see, the model was able to generate an image that mixes the characteristics of both adapters.
|
||||
Impressive! As you can see, the model generated an image that mixed the characteristics of both adapters.
|
||||
|
||||
If you want to go back to using only one adapter, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `"toy"` adapter:
|
||||
> [!TIP]
|
||||
> Through its PEFT integration, Diffusers also offers more efficient merging methods which you can learn about in the [Merge LoRAs](../using-diffusers/merge_loras) guide!
|
||||
|
||||
To return to only using one adapter, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `"toy"` adapter:
|
||||
|
||||
```python
|
||||
# First, set the adapter.
|
||||
pipe.set_adapters("toy")
|
||||
|
||||
# Then, run inference.
|
||||
prompt = "toy_face of a hacker with a hoodie"
|
||||
lora_scale= 0.9
|
||||
image = pipe(
|
||||
@@ -130,11 +121,7 @@ image = pipe(
|
||||
image
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
If you want to switch to only the base model, disable all LoRAs with the [`~diffusers.loaders.UNet2DConditionLoadersMixin.disable_lora`] method.
|
||||
|
||||
Or to disable all adapters entirely, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.disable_lora`] method to return the base model.
|
||||
|
||||
```python
|
||||
pipe.disable_lora()
|
||||
@@ -145,11 +132,9 @@ image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).ima
|
||||
image
|
||||
```
|
||||
|
||||

|
||||
## Manage active adapters
|
||||
|
||||
## Monitoring active adapters
|
||||
|
||||
You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, you can easily check the list of active adapters using the [`~diffusers.loaders.LoraLoaderMixin.get_active_adapters`] method:
|
||||
You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, use the [`~diffusers.loaders.LoraLoaderMixin.get_active_adapters`] method to check the list of active adapters:
|
||||
|
||||
```py
|
||||
active_adapters = pipe.get_active_adapters()
|
||||
@@ -164,74 +149,3 @@ list_adapters_component_wise = pipe.get_list_adapters()
|
||||
list_adapters_component_wise
|
||||
{"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]}
|
||||
```
|
||||
|
||||
## Compatibility with `torch.compile`
|
||||
|
||||
If you want to compile your model with `torch.compile` make sure to first fuse the LoRA weights into the base model and unload them.
|
||||
|
||||
```py
|
||||
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
|
||||
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
|
||||
# Fuses the LoRAs into the Unet
|
||||
pipe.fuse_lora()
|
||||
pipe.unload_lora_weights()
|
||||
|
||||
pipe = torch.compile(pipe)
|
||||
|
||||
prompt = "toy_face of a hacker with a hoodie, pixel art"
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
```
|
||||
|
||||
## Fusing adapters into the model
|
||||
|
||||
You can use PEFT to easily fuse/unfuse multiple adapters directly into the model weights (both UNet and text encoder) using the [`~diffusers.loaders.LoraLoaderMixin.fuse_lora`] method, which can lead to a speed-up in inference and lower VRAM usage.
|
||||
|
||||
```py
|
||||
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
|
||||
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
|
||||
# Fuses the LoRAs into the Unet
|
||||
pipe.fuse_lora()
|
||||
|
||||
prompt = "toy_face of a hacker with a hoodie, pixel art"
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
|
||||
# Gets the Unet back to the original state
|
||||
pipe.unfuse_lora()
|
||||
```
|
||||
|
||||
You can also fuse some adapters using `adapter_names` for faster generation:
|
||||
|
||||
```py
|
||||
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
|
||||
pipe.set_adapters(["pixel"], adapter_weights=[0.5, 1.0])
|
||||
# Fuses the LoRAs into the Unet
|
||||
pipe.fuse_lora(adapter_names=["pixel"])
|
||||
|
||||
prompt = "a hacker with a hoodie, pixel art"
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
|
||||
# Gets the Unet back to the original state
|
||||
pipe.unfuse_lora()
|
||||
|
||||
# Fuse all adapters
|
||||
pipe.fuse_lora(adapter_names=["pixel", "toy"])
|
||||
|
||||
prompt = "toy_face of a hacker with a hoodie, pixel art"
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
```
|
||||
|
||||
## Saving a pipeline after fusing the adapters
|
||||
|
||||
To properly save a pipeline after it's been loaded with the adapters, it should be serialized like so:
|
||||
|
||||
```python
|
||||
pipe.fuse_lora(lora_scale=1.0)
|
||||
pipe.unload_lora_weights()
|
||||
pipe.save_pretrained("path-to-pipeline")
|
||||
```
|
||||
|
||||
@@ -12,13 +12,18 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Pipeline callbacks
|
||||
|
||||
The denoising loop of a pipeline can be modified with custom defined functions using the `callback_on_step_end` parameter. This can be really useful for *dynamically* adjusting certain pipeline attributes, or modifying tensor variables. The flexibility of callbacks opens up some interesting use-cases such as changing the prompt embeddings at each timestep, assigning different weights to the prompt embeddings, and editing the guidance scale.
|
||||
The denoising loop of a pipeline can be modified with custom defined functions using the `callback_on_step_end` parameter. The callback function is executed at the end of each step, and modifies the pipeline attributes and variables for the next step. This is really useful for *dynamically* adjusting certain pipeline attributes or modifying tensor variables. This versatility allows for interesting use-cases such as changing the prompt embeddings at each timestep, assigning different weights to the prompt embeddings, and editing the guidance scale. With callbacks, you can implement new features without modifying the underlying code!
|
||||
|
||||
This guide will show you how to use the `callback_on_step_end` parameter to disable classifier-free guidance (CFG) after 40% of the inference steps to save compute with minimal cost to performance.
|
||||
> [!TIP]
|
||||
> 🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point!
|
||||
|
||||
The callback function should have the following arguments:
|
||||
This guide will demonstrate how callbacks work by a few features you can implement with them.
|
||||
|
||||
* `pipe` (or the pipeline instance) provides access to useful properties such as `num_timesteps` and `guidance_scale`. You can modify these properties by updating the underlying attributes. For this example, you'll disable CFG by setting `pipe._guidance_scale=0.0`.
|
||||
## Dynamic classifier-free guidance
|
||||
|
||||
Dynamic classifier-free guidance (CFG) is a feature that allows you to disable CFG after a certain number of inference steps which can help you save compute with minimal cost to performance. The callback function for this should have the following arguments:
|
||||
|
||||
* `pipeline` (or the pipeline instance) provides access to important properties such as `num_timesteps` and `guidance_scale`. You can modify these properties by updating the underlying attributes. For this example, you'll disable CFG by setting `pipeline._guidance_scale=0.0`.
|
||||
* `step_index` and `timestep` tell you where you are in the denoising loop. Use `step_index` to turn off CFG after reaching 40% of `num_timesteps`.
|
||||
* `callback_kwargs` is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the `callback_on_step_end_tensor_inputs` argument, which is passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables, so please check a pipeline's `_callback_tensor_inputs` attribute for the list of variables you can modify. Some common variables include `latents` and `prompt_embeds`. For this function, change the batch size of `prompt_embeds` after setting `guidance_scale=0.0` in order for it to work properly.
|
||||
|
||||
@@ -27,13 +32,13 @@ Your callback function should look something like this:
|
||||
```python
|
||||
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
|
||||
# adjust the batch_size of prompt_embeds according to guidance_scale
|
||||
if step_index == int(pipe.num_timesteps * 0.4):
|
||||
if step_index == int(pipeline.num_timesteps * 0.4):
|
||||
prompt_embeds = callback_kwargs["prompt_embeds"]
|
||||
prompt_embeds = prompt_embeds.chunk(2)[-1]
|
||||
|
||||
# update guidance_scale and prompt_embeds
|
||||
pipe._guidance_scale = 0.0
|
||||
callback_kwargs["prompt_embeds"] = prompt_embeds
|
||||
# update guidance_scale and prompt_embeds
|
||||
pipeline._guidance_scale = 0.0
|
||||
callback_kwargs["prompt_embeds"] = prompt_embeds
|
||||
return callback_kwargs
|
||||
```
|
||||
|
||||
@@ -43,58 +48,134 @@ Now, you can pass the callback function to the `callback_on_step_end` parameter
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
pipe = pipe.to("cuda")
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
pipeline = pipeline.to("cuda")
|
||||
|
||||
prompt = "a photo of an astronaut riding a horse on mars"
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(1)
|
||||
out = pipe(prompt, generator=generator, callback_on_step_end=callback_dynamic_cfg, callback_on_step_end_tensor_inputs=['prompt_embeds'])
|
||||
out = pipeline(
|
||||
prompt,
|
||||
generator=generator,
|
||||
callback_on_step_end=callback_dynamic_cfg,
|
||||
callback_on_step_end_tensor_inputs=['prompt_embeds']
|
||||
)
|
||||
|
||||
out.images[0].save("out_custom_cfg.png")
|
||||
```
|
||||
|
||||
The callback function is executed at the end of each denoising step, and modifies the pipeline attributes and tensor variables for the next denoising step.
|
||||
|
||||
With callbacks, you can implement features such as dynamic CFG without having to modify the underlying code at all!
|
||||
|
||||
<Tip>
|
||||
|
||||
🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point!
|
||||
|
||||
</Tip>
|
||||
|
||||
## Interrupt the diffusion process
|
||||
|
||||
Interrupting the diffusion process is particularly useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.
|
||||
> [!TIP]
|
||||
> The interruption callback is supported for text-to-image, image-to-image, and inpainting for the [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview) and [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl).
|
||||
|
||||
<Tip>
|
||||
Stopping the diffusion process early is useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.
|
||||
|
||||
The interruption callback is supported for text-to-image, image-to-image, and inpainting for the [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview) and [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl).
|
||||
|
||||
</Tip>
|
||||
|
||||
This callback function should take the following arguments: `pipe`, `i`, `t`, and `callback_kwargs` (this must be returned). Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback.
|
||||
This callback function should take the following arguments: `pipeline`, `i`, `t`, and `callback_kwargs` (this must be returned). Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback.
|
||||
|
||||
In this example, the diffusion process is stopped after 10 steps even though `num_inference_steps` is set to 50.
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
num_inference_steps = 50
|
||||
|
||||
def interrupt_callback(pipe, i, t, callback_kwargs):
|
||||
def interrupt_callback(pipeline, i, t, callback_kwargs):
|
||||
stop_idx = 10
|
||||
if i == stop_idx:
|
||||
pipe._interrupt = True
|
||||
pipeline._interrupt = True
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
pipe(
|
||||
pipeline(
|
||||
"A photo of a cat",
|
||||
num_inference_steps=num_inference_steps,
|
||||
callback_on_step_end=interrupt_callback,
|
||||
)
|
||||
```
|
||||
|
||||
## Display image after each generation step
|
||||
|
||||
> [!TIP]
|
||||
> This tip was contributed by [asomoza](https://github.com/asomoza).
|
||||
|
||||
Display an image after each generation step by accessing and converting the latents after each step into an image. The latent space is compressed to 128x128, so the images are also 128x128 which is useful for a quick preview.
|
||||
|
||||
1. Use the function below to convert the SDXL latents (4 channels) to RGB tensors (3 channels) as explained in the [Explaining the SDXL latent space](https://huggingface.co/blog/TimothyAlexisVass/explaining-the-sdxl-latent-space) blog post.
|
||||
|
||||
```py
|
||||
def latents_to_rgb(latents):
|
||||
weights = (
|
||||
(60, -60, 25, -70),
|
||||
(60, -5, 15, -50),
|
||||
(60, 10, -5, -35)
|
||||
)
|
||||
|
||||
weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device))
|
||||
biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device)
|
||||
rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
|
||||
image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
|
||||
return Image.fromarray(image_array)
|
||||
```
|
||||
|
||||
2. Create a function to decode and save the latents into an image.
|
||||
|
||||
```py
|
||||
def decode_tensors(pipe, step, timestep, callback_kwargs):
|
||||
latents = callback_kwargs["latents"]
|
||||
|
||||
image = latents_to_rgb(latents)
|
||||
image.save(f"{step}.png")
|
||||
|
||||
return callback_kwargs
|
||||
```
|
||||
|
||||
3. Pass the `decode_tensors` function to the `callback_on_step_end` parameter to decode the tensors after each step. You also need to specify what you want to modify in the `callback_on_step_end_tensor_inputs` parameter, which in this case are the latents.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
image = pipe(
|
||||
prompt = "A croissant shaped like a cute bear."
|
||||
negative_prompt = "Deformed, ugly, bad anatomy"
|
||||
callback_on_step_end=decode_tensors,
|
||||
callback_on_step_end_tensor_inputs=["latents"],
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4 justify-center">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/tips_step_0.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">step 0</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/tips_step_19.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">step 19
|
||||
</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/tips_step_29.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">step 29</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/tips_step_39.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">step 39</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/tips_step_49.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">step 49</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -429,6 +429,27 @@ image = pipe(
|
||||
make_image_grid([original_image, canny_image, image], rows=1, cols=3)
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
You can use a refiner model with `StableDiffusionXLControlNetPipeline` to improve image quality, just like you can with a regular `StableDiffusionXLPipeline`.
|
||||
See the [Refine image quality](./sdxl#refine-image-quality) section to learn how to use the refiner model.
|
||||
Make sure to use `StableDiffusionXLControlNetPipeline` and pass `image` and `controlnet_conditioning_scale`.
|
||||
|
||||
```py
|
||||
base = StableDiffusionXLControlNetPipeline(...)
|
||||
image = base(
|
||||
prompt=prompt,
|
||||
controlnet_conditioning_scale=0.5,
|
||||
image=canny_image,
|
||||
num_inference_steps=40,
|
||||
denoising_end=0.8,
|
||||
output_type="latent",
|
||||
).images
|
||||
# rest exactly as with StableDiffusionXLPipeline
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
||||
## MultiControlNet
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -128,7 +128,7 @@ seed = 2023
|
||||
# The values come from
|
||||
# https://github.com/lyn-rgb/FreeU_Diffusers#video-pipelines
|
||||
pipe.enable_freeu(b1=1.2, b2=1.4, s1=0.9, s2=0.2)
|
||||
video_frames = pipe(prompt, height=320, width=576, num_frames=30, generator=torch.manual_seed(seed)).frames
|
||||
video_frames = pipe(prompt, height=320, width=576, num_frames=30, generator=torch.manual_seed(seed)).frames[0]
|
||||
export_to_video(video_frames, "astronaut_rides_horse.mp4")
|
||||
```
|
||||
|
||||
|
||||
@@ -25,6 +25,9 @@ Let's take a look at how to use IP-Adapter's image prompting capabilities with t
|
||||
|
||||
In all the following examples, you'll see the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] method. This method controls the amount of text or image conditioning to apply to the model. A value of `1.0` means the model is only conditioned on the image prompt. Lowering this value encourages the model to produce more diverse images, but they may not be as aligned with the image prompt. Typically, a value of `0.5` achieves a good balance between the two prompt types and produces good results.
|
||||
|
||||
> [!TIP]
|
||||
> In the examples below, try adding `low_cpu_mem_usage=True` to the [`~loaders.IPAdapterMixin.load_ip_adapter`] method to speed up the loading time.
|
||||
|
||||
<hfoptions id="tasks">
|
||||
<hfoption id="Text-to-image">
|
||||
|
||||
@@ -231,10 +234,21 @@ export_to_gif(frames, "gummy_bear.gif")
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
> [!TIP]
|
||||
> While calling `load_ip_adapter()`, pass `low_cpu_mem_usage=True` to speed up the loading time.
|
||||
## Configure parameters
|
||||
|
||||
All the pipelines supporting IP-Adapter accept a `ip_adapter_image_embeds` argument. If you need to run the IP-Adapter multiple times with the same image, you can encode the image once and save the embedding to the disk.
|
||||
There are a couple of IP-Adapter parameters that are useful to know about and can help you with your image generation tasks. These parameters can make your workflow more efficient or give you more control over image generation.
|
||||
|
||||
### Image embeddings
|
||||
|
||||
IP-Adapter enabled pipelines provide the `ip_adapter_image_embeds` parameter to accept precomputed image embeddings. This is particularly useful in scenarios where you need to run the IP-Adapter pipeline multiple times because you have more than one image. For example, [multi IP-Adapter](#multi-ip-adapter) is a specific use case where you provide multiple styling images to generate a specific image in a specific style. Loading and encoding multiple images each time you use the pipeline would be inefficient. Instead, you can precompute and save the image embeddings to disk (which can save a lot of space if you're using high-quality images) and load them when you need them.
|
||||
|
||||
> [!TIP]
|
||||
> This parameter also gives you the flexibility to load embeddings from other sources. For example, ComfyUI image embeddings for IP-Adapters are compatible with Diffusers and should work ouf-of-the-box!
|
||||
|
||||
Call the [`~StableDiffusionPipeline.prepare_ip_adapter_image_embeds`] method to encode and generate the image embeddings. Then you can save them to disk with `torch.save`.
|
||||
|
||||
> [!TIP]
|
||||
> If you're using IP-Adapter with `ip_adapter_image_embedding` instead of `ip_adapter_image`', you can set `load_ip_adapter(image_encoder_folder=None,...)` because you don't need to load an encoder to generate the image embeddings.
|
||||
|
||||
```py
|
||||
image_embeds = pipeline.prepare_ip_adapter_image_embeds(
|
||||
@@ -248,10 +262,7 @@ image_embeds = pipeline.prepare_ip_adapter_image_embeds(
|
||||
torch.save(image_embeds, "image_embeds.ipadpt")
|
||||
```
|
||||
|
||||
Load the image embedding and pass it to the pipeline as `ip_adapter_image_embeds`
|
||||
|
||||
> [!TIP]
|
||||
> ComfyUI image embeddings for IP-Adapters are fully compatible in Diffusers and should work out-of-box.
|
||||
Now load the image embeddings by passing them to the `ip_adapter_image_embeds` parameter.
|
||||
|
||||
```py
|
||||
image_embeds = torch.load("image_embeds.ipadpt")
|
||||
@@ -264,8 +275,86 @@ images = pipeline(
|
||||
).images
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> If you use IP-Adapter with `ip_adapter_image_embedding` instead of `ip_adapter_image`, you can choose not to load an image encoder by passing `image_encoder_folder=None` to `load_ip_adapter()`.
|
||||
### IP-Adapter masking
|
||||
|
||||
Binary masks specify which portion of the output image should be assigned to an IP-Adapter. This is useful for composing more than one IP-Adapter image. For each input IP-Adapter image, you must provide a binary mask an an IP-Adapter.
|
||||
|
||||
To start, preprocess the input IP-Adapter images with the [`~image_processor.IPAdapterMaskProcessor.preprocess()`] to generate their masks. For optimal results, provide the output height and width to [`~image_processor.IPAdapterMaskProcessor.preprocess()`]. This ensures masks with different aspect ratios are appropriately stretched. If the input masks already match the aspect ratio of the generated image, you don't have to set the `height` and `width`.
|
||||
|
||||
```py
|
||||
from diffusers.image_processor import IPAdapterMaskProcessor
|
||||
|
||||
mask1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_mask1.png")
|
||||
mask2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_mask2.png")
|
||||
|
||||
output_height = 1024
|
||||
output_width = 1024
|
||||
|
||||
processor = IPAdapterMaskProcessor()
|
||||
masks = processor.preprocess([mask1, mask2], height=output_height, width=output_width)
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">mask one</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">mask two</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
When there is more than one input IP-Adapter image, load them as a list to ensure each image is assigned to a different IP-Adapter. Each of the input IP-Adapter images here correspond to the masks generated above.
|
||||
|
||||
```py
|
||||
face_image1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png")
|
||||
face_image2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl2.png")
|
||||
|
||||
ip_images = [[face_image1], [face_image2]]
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">IP-Adapter image one</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">IP-Adapter image two</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Now pass the preprocessed masks to `cross_attention_kwargs` in the pipeline call.
|
||||
|
||||
```py
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2)
|
||||
pipeline.set_ip_adapter_scale([0.7] * 2)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
num_images = 1
|
||||
|
||||
image = pipeline(
|
||||
prompt="2 girls",
|
||||
ip_adapter_image=ip_images,
|
||||
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
num_inference_steps=20,
|
||||
num_images_per_prompt=num_images,
|
||||
generator=generator,
|
||||
cross_attention_kwargs={"ip_adapter_masks": masks}
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_attention_mask_result_seed_0.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">IP-Adapter masking applied</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_no_attention_mask_result_seed_0.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">no IP-Adapter masking applied</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Specific use cases
|
||||
|
||||
@@ -279,6 +368,7 @@ Generating accurate faces is challenging because they are complex and nuanced. D
|
||||
* [ip-adapter-plus-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-plus-face_sd15.safetensors) uses patch embeddings and is conditioned with images of cropped faces
|
||||
|
||||
> [!TIP]
|
||||
>
|
||||
> [IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) is a face-specific IP-Adapter trained with face ID embeddings instead of CLIP image embeddings, allowing you to generate more consistent faces in different contexts and styles. Try out this popular [community pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#ip-adapter-face-id) and see how it compares to the other face IP-Adapters.
|
||||
|
||||
For face models, use the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) checkpoint. It is also recommended to use [`DDIMScheduler`] or [`EulerDiscreteScheduler`] for face models.
|
||||
@@ -502,82 +592,3 @@ image
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ipa-controlnet-out.png" />
|
||||
</div>
|
||||
|
||||
### IP-Adapter masking
|
||||
|
||||
Binary masks can be used to specify which portion of the output image should be assigned to an IP-Adapter.
|
||||
For each input IP-Adapter image, a binary mask and an IP-Adapter must be provided.
|
||||
|
||||
Before passing the masks to the pipeline, it's essential to preprocess them using [`IPAdapterMaskProcessor.preprocess()`].
|
||||
|
||||
> [!TIP]
|
||||
> For optimal results, provide the output height and width to [`IPAdapterMaskProcessor.preprocess()`]. This ensures that masks with differing aspect ratios are appropriately stretched. If the input masks already match the aspect ratio of the generated image, specifying height and width can be omitted.
|
||||
|
||||
Here an example with two masks:
|
||||
|
||||
```py
|
||||
from diffusers.image_processor import IPAdapterMaskProcessor
|
||||
|
||||
mask1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png")
|
||||
mask2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png")
|
||||
|
||||
output_height = 1024
|
||||
output_width = 1024
|
||||
|
||||
processor = IPAdapterMaskProcessor()
|
||||
masks = processor.preprocess([mask1, mask2], height=output_height, width=output_width)
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">mask one</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">mask two</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
If you have more than one IP-Adapter image, load them into a list, ensuring each image is assigned to a different IP-Adapter.
|
||||
|
||||
```py
|
||||
face_image1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png")
|
||||
face_image2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png")
|
||||
|
||||
ip_images = [[face_image1], [face_image2]]
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">ip adapter image one</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">ip adapter image two</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Pass preprocessed masks to the pipeline using `cross_attention_kwargs` as shown below:
|
||||
|
||||
```py
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2)
|
||||
pipeline.set_ip_adapter_scale([0.7] * 2)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
num_images = 1
|
||||
|
||||
image = pipeline(
|
||||
prompt="2 girls",
|
||||
ip_adapter_image=ip_images,
|
||||
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
num_inference_steps=20, num_images_per_prompt=num_images,
|
||||
generator=generator, cross_attention_kwargs={"ip_adapter_masks": masks}
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_attention_mask_result_seed_0.png" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">output image</figcaption>
|
||||
</div>
|
||||
|
||||
@@ -103,7 +103,7 @@ image
|
||||
|
||||
<Tip>
|
||||
|
||||
LoRA is a very general training technique that can be used with other training methods. For example, it is common to train a model with DreamBooth and LoRA.
|
||||
LoRA is a very general training technique that can be used with other training methods. For example, it is common to train a model with DreamBooth and LoRA. It is also increasingly common to load and merge multiple LoRAs to create new and unique images. You can learn more about it in the in-depth [Merge LoRAs](merge_loras) guide since merging is outside the scope of this loading guide.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -165,101 +165,14 @@ To unload the LoRA weights, use the [`~loaders.LoraLoaderMixin.unload_lora_weigh
|
||||
pipeline.unload_lora_weights()
|
||||
```
|
||||
|
||||
### Load multiple LoRAs
|
||||
|
||||
It can be fun to use multiple LoRAs together to create something entirely new and unique. The [`~loaders.LoraLoaderMixin.fuse_lora`] method allows you to fuse the LoRA weights with the original weights of the underlying model.
|
||||
|
||||
<Tip>
|
||||
|
||||
Fusing the weights can lead to a speedup in inference latency because you don't need to separately load the base model and LoRA! You can save your fused pipeline with [`~DiffusionPipeline.save_pretrained`] to avoid loading and fusing the weights every time you want to use the model.
|
||||
|
||||
</Tip>
|
||||
|
||||
Load an initial model:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline, AutoencoderKL
|
||||
import torch
|
||||
|
||||
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
vae=vae,
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
Next, load the LoRA checkpoint and fuse it with the original weights. The `lora_scale` parameter controls how much to scale the output by with the LoRA weights. It is important to make the `lora_scale` adjustments in the [`~loaders.LoraLoaderMixin.fuse_lora`] method because it won't work if you try to pass `scale` to the `cross_attention_kwargs` in the pipeline.
|
||||
|
||||
If you need to reset the original model weights for any reason (use a different `lora_scale`), you should use the [`~loaders.LoraLoaderMixin.unfuse_lora`] method.
|
||||
|
||||
```py
|
||||
pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
|
||||
# to unfuse the LoRA weights
|
||||
pipeline.unfuse_lora()
|
||||
```
|
||||
|
||||
Then fuse this pipeline with the next set of LoRA weights:
|
||||
|
||||
```py
|
||||
pipeline.load_lora_weights("ostris/super-cereal-sdxl-lora")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
You can't unfuse multiple LoRA checkpoints, so if you need to reset the model to its original weights, you'll need to reload it.
|
||||
|
||||
</Tip>
|
||||
|
||||
Now you can generate an image that uses the weights from both LoRAs:
|
||||
|
||||
```py
|
||||
prompt = "A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration"
|
||||
image = pipeline(prompt).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
### 🤗 PEFT
|
||||
|
||||
<Tip>
|
||||
|
||||
Read the [Inference with 🤗 PEFT](../tutorials/using_peft_for_inference) tutorial to learn more about its integration with 🤗 Diffusers and how you can easily work with and juggle multiple adapters. You'll need to install 🤗 Diffusers and PEFT from source to run the example in this section.
|
||||
|
||||
</Tip>
|
||||
|
||||
Another way you can load and use multiple LoRAs is to specify the `adapter_name` parameter in [`~loaders.LoraLoaderMixin.load_lora_weights`]. This method takes advantage of the 🤗 PEFT integration. For example, load and name both LoRA weights:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
|
||||
pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
|
||||
pipeline.load_lora_weights("ostris/super-cereal-sdxl-lora", weight_name="cereal_box_sdxl_v1.safetensors", adapter_name="cereal")
|
||||
```
|
||||
|
||||
Now use the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] to activate both LoRAs, and you can configure how much weight each LoRA should have on the output:
|
||||
|
||||
```py
|
||||
pipeline.set_adapters(["ikea", "cereal"], adapter_weights=[0.7, 0.5])
|
||||
```
|
||||
|
||||
Then, generate an image:
|
||||
|
||||
```py
|
||||
prompt = "A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration"
|
||||
image = pipeline(prompt, num_inference_steps=30, cross_attention_kwargs={"scale": 1.0}).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
### Kohya and TheLastBen
|
||||
|
||||
Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
|
||||
|
||||
Let's download the [Blueprintify SD XL 1.0](https://civitai.com/models/150986/blueprintify-sd-xl-10) checkpoint from [Civitai](https://civitai.com/):
|
||||
<hfoptions id="other-trainers">
|
||||
<hfoption id="Kohya">
|
||||
|
||||
To load a Kohya LoRA, let's download the [Blueprintify SD XL 1.0](https://civitai.com/models/150986/blueprintify-sd-xl-10) checkpoint from [Civitai](https://civitai.com/) as an example:
|
||||
|
||||
```sh
|
||||
!wget https://civitai.com/api/download/models/168776 -O blueprintify-sd-xl-10.safetensors
|
||||
@@ -293,6 +206,9 @@ Some limitations of using Kohya LoRAs with 🤗 Diffusers include:
|
||||
|
||||
</Tip>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="TheLastBen">
|
||||
|
||||
Loading a checkpoint from TheLastBen is very similar. For example, to load the [TheLastBen/William_Eggleston_Style_SDXL](https://huggingface.co/TheLastBen/William_Eggleston_Style_SDXL) checkpoint:
|
||||
|
||||
```py
|
||||
@@ -308,6 +224,9 @@ image = pipeline(prompt=prompt).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## IP-Adapter
|
||||
|
||||
[IP-Adapter](https://ip-adapter.github.io/) is a lightweight adapter that enables image prompting for any diffusion model. This adapter works by decoupling the cross-attention layers of the image and text features. All the other model components are frozen and only the embedded image features in the UNet are trained. As a result, IP-Adapter files are typically only ~100MBs.
|
||||
|
||||
@@ -0,0 +1,266 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Merge LoRAs
|
||||
|
||||
It can be fun and creative to use multiple [LoRAs]((https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora)) together to generate something entirely new and unique. This works by merging multiple LoRA weights together to produce images that are a blend of different styles. Diffusers provides a few methods to merge LoRAs depending on *how* you want to merge their weights, which can affect image quality.
|
||||
|
||||
This guide will show you how to merge LoRAs using the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] and [`~peft.LoraModel.add_weighted_adapter`] methods. To improve inference speed and reduce memory-usage of merged LoRAs, you'll also see how to use the [`~loaders.LoraLoaderMixin.fuse_lora`] method to fuse the LoRA weights with the original weights of the underlying model.
|
||||
|
||||
For this guide, load a Stable Diffusion XL (SDXL) checkpoint and the [KappaNeuro/studio-ghibli-style]() and [Norod78/sdxl-chalkboarddrawing-lora]() LoRAs with the [`~loaders.LoraLoaderMixin.load_lora_weights`] method. You'll need to assign each LoRA an `adapter_name` to combine them later.
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
|
||||
pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
|
||||
pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_XL.safetensors", adapter_name="feng")
|
||||
```
|
||||
|
||||
## set_adapters
|
||||
|
||||
The [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] method merges LoRA adapters by concatenating their weighted matrices. Use the adapter name to specify which LoRAs to merge, and the `adapter_weights` parameter to control the scaling for each LoRA. For example, if `adapter_weights=[0.5, 0.5]`, then the merged LoRA output is an average of both LoRAs. Try adjusting the adapter weights to see how it affects the generated image!
|
||||
|
||||
```py
|
||||
pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
prompt = "A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai"
|
||||
image = pipeline(prompt, generator=generator, cross_attention_kwargs={"scale": 1.0}).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lora_merge_set_adapters.png"/>
|
||||
</div>
|
||||
|
||||
## add_weighted_adapter
|
||||
|
||||
> [!WARNING]
|
||||
> This is an experimental method that adds PEFTs [`~peft.LoraModel.add_weighted_adapter`] method to Diffusers to enable more efficient merging methods. Check out this [issue](https://github.com/huggingface/diffusers/issues/6892) if you're interested in learning more about the motivation and design behind this integration.
|
||||
|
||||
The [`~peft.LoraModel.add_weighted_adapter`] method provides access to more efficient merging method such as [TIES and DARE](https://huggingface.co/docs/peft/developer_guides/model_merging). To use these merging methods, make sure you have the latest stable version of Diffusers and PEFT installed.
|
||||
|
||||
```bash
|
||||
pip install -U diffusers peft
|
||||
```
|
||||
|
||||
There are three steps to merge LoRAs with the [`~peft.LoraModel.add_weighted_adapter`] method:
|
||||
|
||||
1. Create a [`~peft.PeftModel`] from the underlying model and LoRA checkpoint.
|
||||
2. Load a base UNet model and the LoRA adapters.
|
||||
3. Merge the adapters using the [`~peft.LoraModel.add_weighted_adapter`] method and the merging method of your choice.
|
||||
|
||||
Let's dive deeper into what these steps entail.
|
||||
|
||||
1. Load a UNet that corresponds to the UNet in the LoRA checkpoint. In this case, both LoRAs use the SDXL UNet as their base model.
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DConditionModel
|
||||
import torch
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
subfolder="unet",
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
Load the SDXL pipeline and the LoRA checkpoints, starting with the [ostris/ikea-instructions-lora-sdxl](https://huggingface.co/ostris/ikea-instructions-lora-sdxl) LoRA.
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
variant="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
unet=unet
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
|
||||
```
|
||||
|
||||
Now you'll create a [`~peft.PeftModel`] from the loaded LoRA checkpoint by combining the SDXL UNet and the LoRA UNet from the pipeline.
|
||||
|
||||
```python
|
||||
from peft import get_peft_model, LoraConfig
|
||||
import copy
|
||||
|
||||
sdxl_unet = copy.deepcopy(unet)
|
||||
ikea_peft_model = get_peft_model(
|
||||
sdxl_unet,
|
||||
pipeline.unet.peft_config["ikea"],
|
||||
adapter_name="ikea"
|
||||
)
|
||||
|
||||
original_state_dict = {f"base_model.model.{k}": v for k, v in pipeline.unet.state_dict().items()}
|
||||
ikea_peft_model.load_state_dict(original_state_dict, strict=True)
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> You can optionally push the ikea_peft_model to the Hub by calling `ikea_peft_model.push_to_hub("ikea_peft_model", token=TOKEN)`.
|
||||
|
||||
Repeat this process to create a [`~peft.PeftModel`] from the [lordjia/by-feng-zikai](https://huggingface.co/lordjia/by-feng-zikai) LoRA.
|
||||
|
||||
```python
|
||||
pipeline.delete_adapters("ikea")
|
||||
sdxl_unet.delete_adapters("ikea")
|
||||
|
||||
pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_XL.safetensors", adapter_name="feng")
|
||||
pipeline.set_adapters(adapter_names="feng")
|
||||
|
||||
feng_peft_model = get_peft_model(
|
||||
sdxl_unet,
|
||||
pipeline.unet.peft_config["feng"],
|
||||
adapter_name="feng"
|
||||
)
|
||||
|
||||
original_state_dict = {f"base_model.model.{k}": v for k, v in pipe.unet.state_dict().items()}
|
||||
feng_peft_model.load_state_dict(original_state_dict, strict=True)
|
||||
```
|
||||
|
||||
2. Load a base UNet model and then load the adapters onto it.
|
||||
|
||||
```python
|
||||
from peft import PeftModel
|
||||
|
||||
base_unet = UNet2DConditionModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
subfolder="unet",
|
||||
).to("cuda")
|
||||
|
||||
model = PeftModel.from_pretrained(base_unet, "stevhliu/ikea_peft_model", use_safetensors=True, subfolder="ikea", adapter_name="ikea")
|
||||
model.load_adapter("stevhliu/feng_peft_model", use_safetensors=True, subfolder="feng", adapter_name="feng")
|
||||
```
|
||||
|
||||
3. Merge the adapters using the [`~peft.LoraModel.add_weighted_adapter`] method and the merging method of your choice (learn more about other merging methods in this [blog post](https://huggingface.co/blog/peft_merging)). For this example, let's use the `"dare_linear"` method to merge the LoRAs.
|
||||
|
||||
> [!WARNING]
|
||||
> Keep in mind the LoRAs need to have the same rank to be merged!
|
||||
|
||||
```python
|
||||
model.add_weighted_adapter(
|
||||
adapters=["ikea", "feng"],
|
||||
weights=[1.0, 1.0],
|
||||
combination_type="dare_linear",
|
||||
adapter_name="ikea-feng"
|
||||
)
|
||||
model.set_adapters("ikea-feng")
|
||||
```
|
||||
|
||||
Now you can generate an image with the merged LoRA.
|
||||
|
||||
```python
|
||||
model = model.to(dtype=torch.float16, device="cuda")
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", unet=model, variant="fp16", torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
|
||||
image = pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai", generator=torch.manual_seed(0)).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ikea-feng-dare-linear.png"/>
|
||||
</div>
|
||||
|
||||
## fuse_lora
|
||||
|
||||
Both the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] and [`~peft.LoraModel.add_weighted_adapter`] methods require loading the base model and the LoRA adapters separately which incurs some overhead. The [`~loaders.LoraLoaderMixin.fuse_lora`] method allows you to fuse the LoRA weights directly with the original weights of the underlying model. This way, you're only loading the model once which can increase inference and lower memory-usage.
|
||||
|
||||
You can use PEFT to easily fuse/unfuse multiple adapters directly into the model weights (both UNet and text encoder) using the [`~loaders.LoraLoaderMixin.fuse_lora`] method, which can lead to a speed-up in inference and lower VRAM usage.
|
||||
|
||||
For example, if you have a base model and adapters loaded and set as active with the following adapter weights:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
|
||||
pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
|
||||
pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_XL.safetensors", adapter_name="feng")
|
||||
|
||||
pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
|
||||
```
|
||||
|
||||
Fuse these LoRAs into the UNet with the [`~loaders.LoraLoaderMixin.fuse_lora`] method. The `lora_scale` parameter controls how much to scale the output by with the LoRA weights. It is important to make the `lora_scale` adjustments in the [`~loaders.LoraLoaderMixin.fuse_lora`] method because it won’t work if you try to pass `scale` to the `cross_attention_kwargs` in the pipeline.
|
||||
|
||||
```py
|
||||
pipeline.fuse_lora(adapter_names=["ikea", "feng"], lora_scale=1.0)
|
||||
```
|
||||
|
||||
Then you should use [`~loaders.LoraLoaderMixin.unload_lora_weights`] to unload the LoRA weights since they've already been fused with the underlying base model. Finally, call [`~DiffusionPipeline.save_pretrained`] to save the fused pipeline locally or you could call [`~DiffusionPipeline.push_to_hub`] to push the fused pipeline to the Hub.
|
||||
|
||||
```py
|
||||
pipeline.unload_lora_weights()
|
||||
# save locally
|
||||
pipeline.save_pretrained("path/to/fused-pipeline")
|
||||
# save to the Hub
|
||||
pipeline.push_to_hub("fused-ikea-feng")
|
||||
```
|
||||
|
||||
Now you can quickly load the fused pipeline and use it for inference without needing to separately load the LoRA adapters.
|
||||
|
||||
```py
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"username/fused-ikea-feng", torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
|
||||
image = pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai", generator=torch.manual_seed(0)).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
You can call [`~loaders.LoraLoaderMixin.unfuse_lora`] to restore the original model's weights (for example, if you want to use a different `lora_scale` value). However, this only works if you've only fused one LoRA adapter to the original model. If you've fused multiple LoRAs, you'll need to reload the model.
|
||||
|
||||
```py
|
||||
pipeline.unfuse_lora()
|
||||
```
|
||||
|
||||
### torch.compile
|
||||
|
||||
[torch.compile](../optimization/torch2.0#torchcompile) can speed up your pipeline even more, but the LoRA weights must be fused first and then unloaded. Typically, the UNet is compiled because it is such a computationally intensive component of the pipeline.
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
# load base model and LoRAs
|
||||
pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
|
||||
pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
|
||||
pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_XL.safetensors", adapter_name="feng")
|
||||
|
||||
# activate both LoRAs and set adapter weights
|
||||
pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
|
||||
|
||||
# fuse LoRAs and unload weights
|
||||
pipeline.fuse_lora(adapter_names=["ikea", "feng"], lora_scale=1.0)
|
||||
pipeline.unload_lora_weights()
|
||||
|
||||
# torch.compile
|
||||
pipeline.unet.to(memory_format=torch.channels_last)
|
||||
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
image = pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai", generator=torch.manual_seed(0)).images[0]
|
||||
```
|
||||
|
||||
Learn more about torch.compile in the [Accelerate inference of text-to-image diffusion models](../tutorials/fast_diffusion#torchcompile) guide.
|
||||
|
||||
## Next steps
|
||||
|
||||
For more conceptual details about how each merging method works, take a look at the [🤗 PEFT welcomes new merging methods](https://huggingface.co/blog/peft_merging#concatenation-cat) blog post!
|
||||
@@ -273,7 +273,6 @@ Lastly, convert the image to a `PIL.Image` to see your generated image!
|
||||
```py
|
||||
>>> image = (image / 2 + 0.5).clamp(0, 1).squeeze()
|
||||
>>> image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
|
||||
>>> image = (image * 255).round().astype("uint8")
|
||||
>>> image = Image.fromarray(image)
|
||||
>>> image
|
||||
```
|
||||
|
||||
@@ -80,8 +80,7 @@ To do so, just specify `--train_text_encoder_ti` while launching training (for r
|
||||
Please keep the following points in mind:
|
||||
|
||||
* SDXL has two text encoders. So, we fine-tune both using LoRA.
|
||||
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memoםהקרry.
|
||||
|
||||
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
|
||||
|
||||
### 3D icon example
|
||||
|
||||
@@ -234,6 +233,32 @@ In ComfyUI we will load a LoRA and a textual embedding at the same time.
|
||||
|
||||
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
|
||||
|
||||
### DoRA training
|
||||
The advanced script now supports DoRA training too!
|
||||
> Proposed in [DoRA: Weight-Decomposed Low-Rank Adaptation](https://arxiv.org/abs/2402.09353),
|
||||
**DoRA** is very similar to LoRA, except it decomposes the pre-trained weight into two components, **magnitude** and **direction** and employs LoRA for _directional_ updates to efficiently minimize the number of trainable parameters.
|
||||
The authors found that by using DoRA, both the learning capacity and training stability of LoRA are enhanced without any additional overhead during inference.
|
||||
|
||||
> [!NOTE]
|
||||
> 💡DoRA training is still _experimental_
|
||||
> and is likely to require different hyperparameter values to perform best compared to a LoRA.
|
||||
> Specifically, we've noticed 2 differences to take into account your training:
|
||||
> 1. **LoRA seem to converge faster than DoRA** (so a set of parameters that may lead to overfitting when training a LoRA may be working well for a DoRA)
|
||||
> 2. **DoRA quality superior to LoRA especially in lower ranks** the difference in quality of DoRA of rank 8 and LoRA of rank 8 appears to be more significant than when training ranks of 32 or 64 for example.
|
||||
> This is also aligned with some of the quantitative analysis shown in the paper.
|
||||
|
||||
**Usage**
|
||||
1. To use DoRA you need to install `peft` from main:
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/peft.git
|
||||
```
|
||||
2. Enable DoRA training by adding this flag
|
||||
```bash
|
||||
--use_dora
|
||||
```
|
||||
**Inference**
|
||||
The inference is the same as if you train a regular LoRA 🤗
|
||||
|
||||
|
||||
### Tips and Tricks
|
||||
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)
|
||||
|
||||
@@ -77,6 +77,7 @@ logger = get_logger(__name__)
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
use_dora: bool,
|
||||
images=None,
|
||||
base_model=str,
|
||||
train_text_encoder=False,
|
||||
@@ -88,6 +89,7 @@ def save_model_card(
|
||||
vae_path=None,
|
||||
):
|
||||
img_str = "widget:\n"
|
||||
lora = "lora" if not use_dora else "dora"
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"""
|
||||
@@ -139,9 +141,10 @@ to trigger concept `{key}` → use `{tokens}` in your prompt \n
|
||||
tags:
|
||||
- stable-diffusion
|
||||
- stable-diffusion-diffusers
|
||||
- diffusers-training
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- lora
|
||||
- {lora}
|
||||
- template:sd-lora
|
||||
{img_str}
|
||||
base_model: {base_model}
|
||||
@@ -651,6 +654,16 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_dora",
|
||||
type=bool,
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=(
|
||||
"Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
|
||||
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_latents",
|
||||
action="store_true",
|
||||
@@ -1219,6 +1232,7 @@ def main(args):
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
use_dora=args.use_dora,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
)
|
||||
@@ -1230,6 +1244,7 @@ def main(args):
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
use_dora=args.use_dora,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
@@ -1955,6 +1970,7 @@ def main(args):
|
||||
|
||||
save_model_card(
|
||||
model_id if not args.push_to_hub else repo_id,
|
||||
use_dora=args.use_dora,
|
||||
images=images,
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
train_text_encoder=args.train_text_encoder,
|
||||
|
||||
@@ -81,6 +81,7 @@ logger = get_logger(__name__)
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
use_dora: bool,
|
||||
images=None,
|
||||
base_model=str,
|
||||
train_text_encoder=False,
|
||||
@@ -92,6 +93,7 @@ def save_model_card(
|
||||
vae_path=None,
|
||||
):
|
||||
img_str = "widget:\n"
|
||||
lora = "lora" if not use_dora else "dora"
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"""
|
||||
@@ -144,9 +146,10 @@ to trigger concept `{key}` → use `{tokens}` in your prompt \n
|
||||
tags:
|
||||
- stable-diffusion-xl
|
||||
- stable-diffusion-xl-diffusers
|
||||
- diffusers-training
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- lora
|
||||
- {lora}
|
||||
- template:sd-lora
|
||||
{img_str}
|
||||
base_model: {base_model}
|
||||
@@ -661,6 +664,15 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_dora",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=(
|
||||
"Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
|
||||
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_latents",
|
||||
action="store_true",
|
||||
@@ -1323,6 +1335,7 @@ def main(args):
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
use_dora=args.use_dora,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
)
|
||||
@@ -1334,6 +1347,7 @@ def main(args):
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
use_dora=args.use_dora,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
@@ -2192,6 +2206,7 @@ def main(args):
|
||||
|
||||
save_model_card(
|
||||
model_id if not args.push_to_hub else repo_id,
|
||||
use_dora=args.use_dora,
|
||||
images=images,
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
train_text_encoder=args.train_text_encoder,
|
||||
|
||||
@@ -105,7 +105,7 @@ pipeline_output = pipe(
|
||||
# processing_res=768, # (optional) Maximum resolution of processing. If set to 0: will not resize at all. Defaults to 768.
|
||||
# match_input_res=True, # (optional) Resize depth prediction to match input resolution.
|
||||
# batch_size=0, # (optional) Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. Defaults to 0.
|
||||
# color_map="Spectral", # (optional) Colormap used to colorize the depth map. Defaults to "Spectral".
|
||||
# color_map="Spectral", # (optional) Colormap used to colorize the depth map. Defaults to "Spectral". Set to `None` to skip colormap generation.
|
||||
# show_progress_bar=True, # (optional) If true, will show progress bars of the inference progress.
|
||||
)
|
||||
|
||||
@@ -750,7 +750,7 @@ This example produces the following images:
|
||||

|
||||
|
||||
### GlueGen Stable Diffusion Pipeline
|
||||
GlueGen is a minimal adapter that allow alignment between any encoder (Text Encoder of different language, Multilingual Roberta, AudioClip) and CLIP text encoder used in standard Stable Diffusion model. This method allows easy language adaptation to available english Stable Diffusion checkpoints without the need of an image captioning dataset as well as long training hours.
|
||||
GlueGen is a minimal adapter that allow alignment between any encoder (Text Encoder of different language, Multilingual Roberta, AudioClip) and CLIP text encoder used in standard Stable Diffusion model. This method allows easy language adaptation to available english Stable Diffusion checkpoints without the need of an image captioning dataset as well as long training hours.
|
||||
|
||||
Make sure you downloaded `gluenet_French_clip_overnorm_over3_noln.ckpt` for French (there are also pre-trained weights for Chinese, Italian, Japanese, Spanish or train your own) at [GlueGen's official repo](https://github.com/salesforce/GlueGen/tree/main)
|
||||
|
||||
@@ -782,9 +782,9 @@ if __name__ == "__main__":
|
||||
).to(device)
|
||||
pipeline.load_language_adapter("gluenet_French_clip_overnorm_over3_noln.ckpt", num_token=token_max_length, dim=1024, dim_out=768, tensor_norm=tensor_norm)
|
||||
|
||||
prompt = "une voiture sur la plage"
|
||||
prompt = "une voiture sur la plage"
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(42)
|
||||
generator = torch.Generator(device=device).manual_seed(42)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
image.save("gluegen_output_fr.png")
|
||||
```
|
||||
@@ -1755,7 +1755,7 @@ with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
|
||||
```
|
||||
|
||||
The following code compares the performance of the original stable diffusion xl pipeline with the ipex-optimized pipeline.
|
||||
By using this optimized pipeline, we can get about 1.4-2 times performance boost with BFloat16 on fourth generation of Intel Xeon CPUs,
|
||||
By using this optimized pipeline, we can get about 1.4-2 times performance boost with BFloat16 on fourth generation of Intel Xeon CPUs,
|
||||
code-named Sapphire Rapids.
|
||||
|
||||
```python
|
||||
@@ -1826,7 +1826,7 @@ This approach is using (optional) CoCa model to avoid writing image description.
|
||||
|
||||
This SDXL pipeline support unlimited length prompt and negative prompt, compatible with A1111 prompt weighted style.
|
||||
|
||||
You can provide both `prompt` and `prompt_2`. If only one prompt is provided, `prompt_2` will be a copy of the provided `prompt`. Here is a sample code to use this pipeline.
|
||||
You can provide both `prompt` and `prompt_2`. If only one prompt is provided, `prompt_2` will be a copy of the provided `prompt`. Here is a sample code to use this pipeline.
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
@@ -3397,7 +3397,7 @@ invert_prompt = "A lying cat"
|
||||
input_image = "siamese.jpg"
|
||||
steps = 50
|
||||
|
||||
# Provide prompt used for generation. Same if reconstruction
|
||||
# Provide prompt used for generation. Same if reconstruction
|
||||
prompt = "A lying cat"
|
||||
# or different if editing.
|
||||
prompt = "A lying dog"
|
||||
@@ -3414,15 +3414,13 @@ pipeline(prompt, uncond, inverted_latent, guidance_scale=7.5, num_inference_step
|
||||
|
||||
### Rerender A Video
|
||||
|
||||
This is the Diffusers implementation of zero-shot video-to-video translation pipeline [Rerender A Video](https://github.com/williamyang1991/Rerender_A_Video) (without Ebsynth postprocessing). To run the code, please install gmflow. Then modify the path in `examples/community/rerender_a_video.py`:
|
||||
This is the Diffusers implementation of zero-shot video-to-video translation pipeline [Rerender A Video](https://github.com/williamyang1991/Rerender_A_Video) (without Ebsynth postprocessing). To run the code, please install gmflow. Then modify the path in `gmflow_dir`. After that, you can run the pipeline with:
|
||||
|
||||
```py
|
||||
import sys
|
||||
gmflow_dir = "/path/to/gmflow"
|
||||
```
|
||||
sys.path.insert(0, gmflow_dir)
|
||||
|
||||
After that, you can run the pipeline with:
|
||||
|
||||
```py
|
||||
from diffusers import ControlNetModel, AutoencoderKL, DDIMScheduler
|
||||
from diffusers.utils import export_to_video
|
||||
import numpy as np
|
||||
@@ -3493,7 +3491,7 @@ output_frames = pipe(
|
||||
mask_end=0.8,
|
||||
mask_strength=0.5,
|
||||
negative_prompt='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
|
||||
).frames
|
||||
).frames[0]
|
||||
|
||||
export_to_video(
|
||||
output_frames, "/path/to/video.mp4", 5)
|
||||
@@ -3636,8 +3634,8 @@ image = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
|
||||
images = pipeline(
|
||||
prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower",
|
||||
image_embeds=image,
|
||||
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
num_inference_steps=20, num_images_per_prompt=num_images, width=512, height=704,
|
||||
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
num_inference_steps=20, num_images_per_prompt=num_images, width=512, height=704,
|
||||
generator=generator
|
||||
).images
|
||||
|
||||
|
||||
@@ -50,14 +50,14 @@ class MarigoldDepthOutput(BaseOutput):
|
||||
Args:
|
||||
depth_np (`np.ndarray`):
|
||||
Predicted depth map, with depth values in the range of [0, 1].
|
||||
depth_colored (`PIL.Image.Image`):
|
||||
depth_colored (`None` or `PIL.Image.Image`):
|
||||
Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
|
||||
uncertainty (`None` or `np.ndarray`):
|
||||
Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
|
||||
"""
|
||||
|
||||
depth_np: np.ndarray
|
||||
depth_colored: Image.Image
|
||||
depth_colored: Union[None, Image.Image]
|
||||
uncertainty: Union[None, np.ndarray]
|
||||
|
||||
|
||||
@@ -139,14 +139,15 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
If set to 0, the script will automatically decide the proper batch size.
|
||||
show_progress_bar (`bool`, *optional*, defaults to `True`):
|
||||
Display a progress bar of diffusion denoising.
|
||||
color_map (`str`, *optional*, defaults to `"Spectral"`):
|
||||
color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
|
||||
Colormap used to colorize the depth map.
|
||||
ensemble_kwargs (`dict`, *optional*, defaults to `None`):
|
||||
Arguments for detailed ensembling settings.
|
||||
Returns:
|
||||
`MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
|
||||
- **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
|
||||
- **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1]
|
||||
- **depth_colored** (`None` or `PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and
|
||||
values in [0, 1]. None if `color_map` is `None`
|
||||
- **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
|
||||
coming from ensembling. None if `ensemble_size = 1`
|
||||
"""
|
||||
@@ -233,12 +234,15 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
depth_pred = depth_pred.clip(0, 1)
|
||||
|
||||
# Colorize
|
||||
depth_colored = self.colorize_depth_maps(
|
||||
depth_pred, 0, 1, cmap=color_map
|
||||
).squeeze() # [3, H, W], value in (0, 1)
|
||||
depth_colored = (depth_colored * 255).astype(np.uint8)
|
||||
depth_colored_hwc = self.chw2hwc(depth_colored)
|
||||
depth_colored_img = Image.fromarray(depth_colored_hwc)
|
||||
if color_map is not None:
|
||||
depth_colored = self.colorize_depth_maps(
|
||||
depth_pred, 0, 1, cmap=color_map
|
||||
).squeeze() # [3, H, W], value in (0, 1)
|
||||
depth_colored = (depth_colored * 255).astype(np.uint8)
|
||||
depth_colored_hwc = self.chw2hwc(depth_colored)
|
||||
depth_colored_img = Image.fromarray(depth_colored_hwc)
|
||||
else:
|
||||
depth_colored_img = None
|
||||
return MarigoldDepthOutput(
|
||||
depth_np=depth_pred,
|
||||
depth_colored=depth_colored_img,
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -27,6 +26,7 @@ from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionL
|
||||
from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.models.unets.unet_motion_model import MotionAdapter
|
||||
from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
|
||||
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from diffusers.schedulers import (
|
||||
@@ -37,7 +37,7 @@ from diffusers.schedulers import (
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
|
||||
|
||||
|
||||
@@ -91,10 +91,8 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
|
||||
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
# Based on:
|
||||
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
@@ -103,14 +101,18 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
if output_type == "np":
|
||||
outputs = np.stack(outputs)
|
||||
|
||||
elif output_type == "pt":
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnimateDiffControlNetPipelineOutput(BaseOutput):
|
||||
frames: Union[torch.Tensor, np.ndarray]
|
||||
|
||||
|
||||
class AnimateDiffControlNetPipeline(
|
||||
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin
|
||||
):
|
||||
@@ -843,8 +845,8 @@ class AnimateDiffControlNetPipeline(
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
|
||||
[`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
|
||||
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
|
||||
"""
|
||||
|
||||
@@ -1020,7 +1022,7 @@ class AnimateDiffControlNetPipeline(
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
||||
|
||||
# Denoising loop
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
@@ -1096,21 +1098,17 @@ class AnimateDiffControlNetPipeline(
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 9. Post processing
|
||||
if output_type == "latent":
|
||||
return AnimateDiffControlNetPipelineOutput(frames=latents)
|
||||
|
||||
# Post-processing
|
||||
video_tensor = self.decode_latents(latents)
|
||||
|
||||
if output_type == "pt":
|
||||
video = video_tensor
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
# 10. Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return AnimateDiffControlNetPipelineOutput(frames=video)
|
||||
return AnimateDiffPipelineOutput(frames=video)
|
||||
|
||||
@@ -158,10 +158,8 @@ def slerp(
|
||||
return v2
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
|
||||
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
# Based on:
|
||||
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
@@ -170,6 +168,15 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
if output_type == "np":
|
||||
outputs = np.stack(outputs)
|
||||
|
||||
elif output_type == "pt":
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -826,8 +833,8 @@ class AnimateDiffImgToVideoPipeline(
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`AnimateDiffPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
|
||||
[`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
|
||||
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
@@ -958,11 +965,10 @@ class AnimateDiffImgToVideoPipeline(
|
||||
return AnimateDiffPipelineOutput(frames=latents)
|
||||
|
||||
# 10. Post-processing
|
||||
video_tensor = self.decode_latents(latents)
|
||||
|
||||
if output_type == "pt":
|
||||
video = video_tensor
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
|
||||
# 11. Offload all models
|
||||
|
||||
@@ -15,18 +15,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers.models.attention import Attention
|
||||
from diffusers.pipelines.stable_diffusion import (
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionPipelineOutput,
|
||||
from packaging import version
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.configuration_utils import FrozenDict, deprecate
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import (
|
||||
FromSingleFileMixin,
|
||||
IPAdapterMixin,
|
||||
LoraLoaderMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from diffusers.models.attention import Attention
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
@@ -43,34 +72,486 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
class Prompt2PromptPipeline(StableDiffusionPipeline):
|
||||
class Prompt2PromptPipeline(
|
||||
DiffusionPipeline,
|
||||
TextualInversionLoaderMixin,
|
||||
LoraLoaderMixin,
|
||||
IPAdapterMixin,
|
||||
FromSingleFileMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
|
||||
Args:
|
||||
Prompt-to-Prompt-Pipeline for text-to-image generation using Stable Diffusion. This model inherits from
|
||||
[`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for
|
||||
all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler
|
||||
([`SchedulerMixin`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`~transformers.CLIPTextModel`]):
|
||||
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
tokenizer ([`~transformers.CLIPTokenizer`]):
|
||||
A `CLIPTokenizer` to tokenize text.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
A `UNet2DConditionModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
||||
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
||||
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
||||
" file"
|
||||
)
|
||||
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
"""
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
# textual inversion: process multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
else:
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# Access the `hidden_states` first, that contains a tuple of
|
||||
# all the hidden states from the encoder layers. Then index into
|
||||
# the tuple to access the hidden states from the desired layer.
|
||||
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
||||
# We also need to apply the final LayerNorm here to not mess with the
|
||||
# representations. The `last_hidden_states` that we typically use for
|
||||
# obtaining the final prompt representations passes through the LayerNorm
|
||||
# layer.
|
||||
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
prompt_embeds_dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
prompt_embeds_dtype = self.unet.dtype
|
||||
else:
|
||||
prompt_embeds_dtype = prompt_embeds.dtype
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
# textual inversion: process multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
else:
|
||||
if torch.is_tensor(image):
|
||||
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
||||
else:
|
||||
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
||||
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
ip_adapter_image=None,
|
||||
ip_adapter_image_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
||||
raise ValueError(
|
||||
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
@@ -21,6 +20,7 @@ import PIL.Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as T
|
||||
from gmflow.gmflow import GMFlow
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
@@ -34,13 +34,6 @@ from diffusers.utils import BaseOutput, deprecate, logging
|
||||
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
|
||||
|
||||
|
||||
gmflow_dir = "/path/to/gmflow"
|
||||
sys.path.insert(0, gmflow_dir)
|
||||
from gmflow.gmflow import GMFlow # noqa: E402
|
||||
|
||||
from utils.utils import InputPadder # noqa: E402
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@@ -119,11 +112,11 @@ def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_warped_and_mask(flow_model, image1, image2, image3=None, pixel_consistency=False):
|
||||
def get_warped_and_mask(flow_model, image1, image2, image3=None, pixel_consistency=False, device=None):
|
||||
if image3 is None:
|
||||
image3 = image1
|
||||
padder = InputPadder(image1.shape, padding_factor=8)
|
||||
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
|
||||
image1, image2 = padder.pad(image1[None].to(device), image2[None].to(device))
|
||||
results_dict = flow_model(
|
||||
image1, image2, attn_splits_list=[2], corr_radius_list=[-1], prop_radius_list=[-1], pred_bidir_flow=True
|
||||
)
|
||||
@@ -307,6 +300,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
image_encoder=None,
|
||||
requires_safety_checker: bool = True,
|
||||
device=None,
|
||||
):
|
||||
super().__init__(
|
||||
vae,
|
||||
@@ -320,6 +314,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
image_encoder,
|
||||
requires_safety_checker,
|
||||
)
|
||||
self.to(device)
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
@@ -374,7 +369,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
attention_type="swin",
|
||||
ffn_dim_expansion=4,
|
||||
num_transformer_layers=6,
|
||||
).to("cuda")
|
||||
).to(self.device)
|
||||
|
||||
checkpoint = torch.utils.model_zoo.load_url(
|
||||
"https://huggingface.co/Anonymous-sub/Rerender/resolve/main/models/gmflow_sintel-0c07dcb3.pth",
|
||||
@@ -928,13 +923,13 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
prev_image = self.image_processor.preprocess(prev_image).to(dtype=torch.float32)
|
||||
|
||||
warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
|
||||
self.flow_model, first_image, image[0], first_result, False
|
||||
self.flow_model, first_image, image[0], first_result, False, self.device
|
||||
)
|
||||
blend_mask_0 = blur(F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4))
|
||||
blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1)
|
||||
|
||||
warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask(
|
||||
self.flow_model, prev_image[0], image[0], prev_result, False
|
||||
self.flow_model, prev_image[0], image[0], prev_result, False, self.device
|
||||
)
|
||||
blend_mask_pre = blur(F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4))
|
||||
blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1)
|
||||
@@ -1176,3 +1171,24 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
return output_frames
|
||||
|
||||
return TextToVideoSDPipelineOutput(frames=output_frames)
|
||||
|
||||
|
||||
class InputPadder:
|
||||
"""Pads images such that dimensions are divisible by 8"""
|
||||
|
||||
def __init__(self, dims, mode="sintel", padding_factor=8):
|
||||
self.ht, self.wd = dims[-2:]
|
||||
pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor
|
||||
pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor
|
||||
if mode == "sintel":
|
||||
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
|
||||
else:
|
||||
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
|
||||
|
||||
def pad(self, *inputs):
|
||||
return [F.pad(x, self._pad, mode="replicate") for x in inputs]
|
||||
|
||||
def unpad(self, x):
|
||||
ht, wd = x.shape[-2:]
|
||||
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
|
||||
return x[..., c[0] : c[1], c[2] : c[3]]
|
||||
|
||||
@@ -242,6 +242,7 @@ These are controlnet weights trained on {base_model} with new type of conditioni
|
||||
"text-to-image",
|
||||
"diffusers",
|
||||
"controlnet",
|
||||
"diffusers-training",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
|
||||
@@ -169,6 +169,7 @@ These are controlnet weights trained on {base_model} with new type of conditioni
|
||||
"diffusers",
|
||||
"controlnet",
|
||||
"jax-diffusers-event",
|
||||
"diffusers-training",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
|
||||
@@ -243,6 +243,7 @@ These are controlnet weights trained on {base_model} with new type of conditioni
|
||||
"text-to-image",
|
||||
"diffusers",
|
||||
"controlnet",
|
||||
"diffusers-training",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
|
||||
@@ -97,7 +97,14 @@ These are Custom Diffusion adaption weights for {base_model}. The weights were t
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = ["text-to-image", "diffusers", "stable-diffusion", "stable-diffusion-diffusers", "custom-diffusion"]
|
||||
tags = [
|
||||
"text-to-image",
|
||||
"diffusers",
|
||||
"stable-diffusion",
|
||||
"stable-diffusion-diffusers",
|
||||
"custom-diffusion",
|
||||
"diffusers-training",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
|
||||
@@ -206,3 +206,66 @@ You can explore the results from a couple of our internal experiments by checkin
|
||||
## Running on a free-tier Colab Notebook
|
||||
|
||||
Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb).
|
||||
|
||||
## Conducting EDM-style training
|
||||
|
||||
It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364).
|
||||
|
||||
For the SDXL model, simple set:
|
||||
|
||||
```diff
|
||||
+ --do_edm_style_training \
|
||||
```
|
||||
|
||||
Other SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. Below is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_dreambooth_lora_sdxl.py \
|
||||
--pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic" \
|
||||
--instance_data_dir="dog" \
|
||||
--output_dir="dog-playground-lora" \
|
||||
--mixed_precision="fp16" \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--learning_rate=1e-4 \
|
||||
--use_8bit_adam \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=500 \
|
||||
--validation_prompt="A photo of sks dog in a bucket" \
|
||||
--validation_epochs=25 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
> [!CAUTION]
|
||||
> Min-SNR gamma is not supported with the EDM-style training yet. When training with the PlaygroundAI model, it's recommended to not pass any "variant".
|
||||
|
||||
### DoRA training
|
||||
The script now supports DoRA training too!
|
||||
> Proposed in [DoRA: Weight-Decomposed Low-Rank Adaptation](https://arxiv.org/abs/2402.09353),
|
||||
**DoRA** is very similar to LoRA, except it decomposes the pre-trained weight into two components, **magnitude** and **direction** and employs LoRA for _directional_ updates to efficiently minimize the number of trainable parameters.
|
||||
The authors found that by using DoRA, both the learning capacity and training stability of LoRA are enhanced without any additional overhead during inference.
|
||||
|
||||
> [!NOTE]
|
||||
> 💡DoRA training is still _experimental_
|
||||
> and is likely to require different hyperparameter values to perform best compared to a LoRA.
|
||||
> Specifically, we've noticed 2 differences to take into account your training:
|
||||
> 1. **LoRA seem to converge faster than DoRA** (so a set of parameters that may lead to overfitting when training a LoRA may be working well for a DoRA)
|
||||
> 2. **DoRA quality superior to LoRA especially in lower ranks** the difference in quality of DoRA of rank 8 and LoRA of rank 8 appears to be more significant than when training ranks of 32 or 64 for example.
|
||||
> This is also aligned with some of the quantitative analysis shown in the paper.
|
||||
|
||||
**Usage**
|
||||
1. To use DoRA you need to install `peft` from main:
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/peft.git
|
||||
```
|
||||
2. Enable DoRA training by adding this flag
|
||||
```bash
|
||||
--use_dora
|
||||
```
|
||||
**Inference**
|
||||
The inference is the same as if you train a regular LoRA 🤗
|
||||
@@ -0,0 +1,99 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRASDXLWithEDM(ExamplesTestsAccelerate):
|
||||
def test_dreambooth_lora_sdxl_with_edm(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora_sdxl.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
|
||||
--do_edm_style_training
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"unet"` in their names.
|
||||
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_unet)
|
||||
|
||||
def test_dreambooth_lora_playground(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora_sdxl.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-playground-v2-5-pipe
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"unet"` in their names.
|
||||
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_unet)
|
||||
@@ -102,7 +102,7 @@ DreamBooth for the text encoder was enabled: {train_text_encoder}.
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = ["text-to-image", "dreambooth"]
|
||||
tags = ["text-to-image", "dreambooth", "diffusers-training"]
|
||||
if isinstance(pipeline, StableDiffusionPipeline):
|
||||
tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
|
||||
else:
|
||||
|
||||
@@ -106,7 +106,7 @@ LoRA for the text encoder was enabled: {train_text_encoder}.
|
||||
model_description=model_description,
|
||||
inference=True,
|
||||
)
|
||||
tags = ["text-to-image", "diffusers", "lora"]
|
||||
tags = ["text-to-image", "diffusers", "lora", "diffusers-training"]
|
||||
if isinstance(pipeline, StableDiffusionPipeline):
|
||||
tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
|
||||
else:
|
||||
|
||||
@@ -14,8 +14,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import gc
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -32,7 +34,7 @@ import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from huggingface_hub import create_repo, hf_hub_download, upload_folder
|
||||
from huggingface_hub.utils import insecure_hashlib
|
||||
from packaging import version
|
||||
from peft import LoraConfig, set_peft_model_state_dict
|
||||
@@ -50,6 +52,8 @@ from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EDMEulerScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
StableDiffusionXLPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
@@ -76,8 +80,23 @@ check_min_version("0.27.0.dev0")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def determine_scheduler_type(pretrained_model_name_or_path, revision):
|
||||
model_index_filename = "model_index.json"
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)
|
||||
else:
|
||||
model_index = hf_hub_download(
|
||||
repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision
|
||||
)
|
||||
|
||||
with open(model_index, "r") as f:
|
||||
scheduler_type = json.load(f)["scheduler"][1]
|
||||
return scheduler_type
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
use_dora: bool,
|
||||
images=None,
|
||||
base_model: str = None,
|
||||
train_text_encoder=False,
|
||||
@@ -95,7 +114,7 @@ def save_model_card(
|
||||
)
|
||||
|
||||
model_description = f"""
|
||||
# SDXL LoRA DreamBooth - {repo_id}
|
||||
# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
|
||||
|
||||
<Gallery />
|
||||
|
||||
@@ -119,11 +138,17 @@ Weights for this model are available in Safetensors format.
|
||||
|
||||
[Download]({repo_id}/tree/main) them in the Files & versions tab.
|
||||
|
||||
"""
|
||||
if "playground" in base_model:
|
||||
model_description += """\n
|
||||
## License
|
||||
|
||||
Please adhere to the licensing terms as described [here](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md).
|
||||
"""
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="openrail++",
|
||||
license="openrail++" if "playground" not in base_model else "playground-v2dot5-community",
|
||||
base_model=base_model,
|
||||
prompt=instance_prompt,
|
||||
model_description=model_description,
|
||||
@@ -131,15 +156,18 @@ Weights for this model are available in Safetensors format.
|
||||
)
|
||||
tags = [
|
||||
"text-to-image",
|
||||
"stable-diffusion-xl",
|
||||
"stable-diffusion-xl-diffusers",
|
||||
"text-to-image",
|
||||
"diffusers-training",
|
||||
"diffusers",
|
||||
"lora",
|
||||
"lora" if not use_dora else "dora",
|
||||
"template:sd-lora",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
if "playground" in base_model:
|
||||
tags.extend(["playground", "playground-diffusers"])
|
||||
else:
|
||||
tags.extend(["stable-diffusion-xl", "stable-diffusion-xl-diffusers"])
|
||||
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
|
||||
|
||||
@@ -159,23 +187,29 @@ def log_validation(
|
||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||
scheduler_args = {}
|
||||
|
||||
if "variance_type" in pipeline.scheduler.config:
|
||||
variance_type = pipeline.scheduler.config.variance_type
|
||||
if not args.do_edm_style_training:
|
||||
if "variance_type" in pipeline.scheduler.config:
|
||||
variance_type = pipeline.scheduler.config.variance_type
|
||||
|
||||
if variance_type in ["learned", "learned_range"]:
|
||||
variance_type = "fixed_small"
|
||||
if variance_type in ["learned", "learned_range"]:
|
||||
variance_type = "fixed_small"
|
||||
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
|
||||
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
|
||||
inference_ctx = (
|
||||
contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
|
||||
)
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
with inference_ctx:
|
||||
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
@@ -334,6 +368,12 @@ def parse_args(input_args=None):
|
||||
" `args.validation_prompt` multiple times: `args.num_validation_images`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_edm_style_training",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
@@ -607,6 +647,15 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_dora",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=(
|
||||
"Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
|
||||
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
|
||||
),
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -828,6 +877,8 @@ def collate_fn(examples, with_prior_preservation=False):
|
||||
if with_prior_preservation:
|
||||
pixel_values += [example["class_images"] for example in examples]
|
||||
prompts += [example["class_prompt"] for example in examples]
|
||||
original_sizes += [example["original_size"] for example in examples]
|
||||
crop_top_lefts += [example["crop_top_left"] for example in examples]
|
||||
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
@@ -905,6 +956,9 @@ def main(args):
|
||||
" Please use `huggingface-cli login` to authenticate with the Hub."
|
||||
)
|
||||
|
||||
if args.do_edm_style_training and args.snr_gamma is not None:
|
||||
raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
|
||||
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
@@ -1018,7 +1072,19 @@ def main(args):
|
||||
)
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)
|
||||
if "EDM" in scheduler_type:
|
||||
args.do_edm_style_training = True
|
||||
noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
logger.info("Performing EDM-style training!")
|
||||
elif args.do_edm_style_training:
|
||||
noise_scheduler = EulerDiscreteScheduler.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="scheduler"
|
||||
)
|
||||
logger.info("Performing EDM-style training!")
|
||||
else:
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
|
||||
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
||||
)
|
||||
@@ -1036,6 +1102,12 @@ def main(args):
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
||||
)
|
||||
@@ -1086,6 +1158,7 @@ def main(args):
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
use_dora=args.use_dora,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
@@ -1097,6 +1170,7 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
use_dora=args.use_dora,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
@@ -1178,7 +1252,7 @@ def main(args):
|
||||
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
|
||||
|
||||
_set_state_dict_into_text_encoder(
|
||||
lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_
|
||||
lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
|
||||
)
|
||||
|
||||
# Make sure the trainable params are in float32. This is again needed since the base models
|
||||
@@ -1433,7 +1507,12 @@ def main(args):
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initializes automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args))
|
||||
tracker_name = (
|
||||
"dreambooth-lora-sd-xl"
|
||||
if "playground" not in args.pretrained_model_name_or_path
|
||||
else "dreambooth-lora-playground"
|
||||
)
|
||||
accelerator.init_trackers(tracker_name, config=vars(args))
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
@@ -1485,6 +1564,18 @@ def main(args):
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
|
||||
sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
|
||||
schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
|
||||
timesteps = timesteps.to(accelerator.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < n_dim:
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
return sigma
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
if args.train_text_encoder:
|
||||
@@ -1512,22 +1603,46 @@ def main(args):
|
||||
|
||||
# Convert images to latent space
|
||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||
model_input = model_input * vae.config.scaling_factor
|
||||
if args.pretrained_vae_model_name_or_path is None:
|
||||
model_input = model_input.to(weight_dtype)
|
||||
|
||||
if latents_mean is None and latents_std is None:
|
||||
model_input = model_input * vae.config.scaling_factor
|
||||
if args.pretrained_vae_model_name_or_path is None:
|
||||
model_input = model_input.to(weight_dtype)
|
||||
else:
|
||||
latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
|
||||
latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
|
||||
model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
|
||||
model_input = model_input.to(dtype=weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(model_input)
|
||||
bsz = model_input.shape[0]
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
if not args.do_edm_style_training:
|
||||
timesteps = torch.randint(
|
||||
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
else:
|
||||
# in EDM formulation, the model is conditioned on the pre-conditioned noise levels
|
||||
# instead of discrete timesteps, so here we sample indices to get the noise levels
|
||||
# from `scheduler.timesteps`
|
||||
indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
|
||||
timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)
|
||||
|
||||
# Add noise to the model input according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
|
||||
# For EDM-style training, we first obtain the sigmas based on the continuous timesteps.
|
||||
# We then precondition the final model inputs based on these sigmas instead of the timesteps.
|
||||
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
|
||||
if args.do_edm_style_training:
|
||||
sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)
|
||||
if "EDM" in scheduler_type:
|
||||
inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)
|
||||
else:
|
||||
inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)
|
||||
|
||||
# time ids
|
||||
add_time_ids = torch.cat(
|
||||
@@ -1551,7 +1666,7 @@ def main(args):
|
||||
}
|
||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
|
||||
model_pred = unet(
|
||||
noisy_model_input,
|
||||
inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
|
||||
timesteps,
|
||||
prompt_embeds_input,
|
||||
added_cond_kwargs=unet_added_conditions,
|
||||
@@ -1570,18 +1685,43 @@ def main(args):
|
||||
)
|
||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
|
||||
model_pred = unet(
|
||||
noisy_model_input,
|
||||
inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
|
||||
timesteps,
|
||||
prompt_embeds_input,
|
||||
added_cond_kwargs=unet_added_conditions,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
weighting = None
|
||||
if args.do_edm_style_training:
|
||||
# Similar to the input preconditioning, the model predictions are also preconditioned
|
||||
# on noised model inputs (before preconditioning) and the sigmas.
|
||||
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
|
||||
if "EDM" in scheduler_type:
|
||||
model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)
|
||||
else:
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
model_pred = model_pred * (-sigmas) + noisy_model_input
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (
|
||||
noisy_model_input / (sigmas**2 + 1)
|
||||
)
|
||||
# We are not doing weighting here because it tends result in numerical problems.
|
||||
# See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
|
||||
# There might be other alternatives for weighting as well:
|
||||
# https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686
|
||||
if "EDM" not in scheduler_type:
|
||||
weighting = (sigmas**-2.0).float()
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
target = model_input if args.do_edm_style_training else noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
|
||||
target = (
|
||||
model_input
|
||||
if args.do_edm_style_training
|
||||
else noise_scheduler.get_velocity(model_input, noise, timesteps)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
@@ -1591,10 +1731,28 @@ def main(args):
|
||||
target, target_prior = torch.chunk(target, 2, dim=0)
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
|
||||
if weighting is not None:
|
||||
prior_loss = torch.mean(
|
||||
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
|
||||
target_prior.shape[0], -1
|
||||
),
|
||||
1,
|
||||
)
|
||||
prior_loss = prior_loss.mean()
|
||||
else:
|
||||
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
|
||||
|
||||
if args.snr_gamma is None:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
if weighting is not None:
|
||||
loss = torch.mean(
|
||||
(weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(
|
||||
target.shape[0], -1
|
||||
),
|
||||
1,
|
||||
)
|
||||
loss = loss.mean()
|
||||
else:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
else:
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
@@ -1696,7 +1854,6 @@ def main(args):
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
|
||||
images = log_validation(
|
||||
@@ -1770,6 +1927,7 @@ def main(args):
|
||||
if args.push_to_hub:
|
||||
save_model_card(
|
||||
repo_id,
|
||||
use_dora=args.use_dora,
|
||||
images=images,
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
train_text_encoder=args.train_text_encoder,
|
||||
|
||||
@@ -81,6 +81,7 @@ tags:
|
||||
- kandinsky
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- diffusers-training
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
|
||||
@@ -65,6 +65,7 @@ tags:
|
||||
- kandinsky
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- diffusers-training
|
||||
- lora
|
||||
inference: true
|
||||
---
|
||||
|
||||
@@ -65,6 +65,7 @@ tags:
|
||||
- kandinsky
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- diffusers-training
|
||||
- lora
|
||||
inference: true
|
||||
---
|
||||
|
||||
@@ -82,6 +82,7 @@ tags:
|
||||
- kandinsky
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- diffusers-training
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
|
||||
@@ -460,6 +460,8 @@ tags:
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- controlnet
|
||||
- diffusers-training
|
||||
- webdataset
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
|
||||
@@ -61,6 +61,34 @@ accelerate launch train_diffusion_dpo_sdxl.py \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
## SDXL Turbo training command
|
||||
|
||||
```bash
|
||||
accelerate launch train_diffusion_dpo_sdxl.py \
|
||||
--pretrained_model_name_or_path=stabilityai/sdxl-turbo \
|
||||
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
|
||||
--output_dir="diffusion-sdxl-turbo-dpo" \
|
||||
--mixed_precision="fp16" \
|
||||
--dataset_name=kashif/pickascore \
|
||||
--train_batch_size=8 \
|
||||
--gradient_accumulation_steps=2 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--rank=8 \
|
||||
--learning_rate=1e-5 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=2000 \
|
||||
--checkpointing_steps=500 \
|
||||
--run_validation --validation_steps=50 \
|
||||
--seed="0" \
|
||||
--report_to="wandb" \
|
||||
--is_turbo --resolution 512 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
This is based on the amazing work done by [Bram](https://github.com/bram-w) here for Diffusion DPO: https://github.com/bram-w/trl/blob/dpo/.
|
||||
|
||||
@@ -118,9 +118,16 @@ def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_v
|
||||
images = []
|
||||
context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()
|
||||
|
||||
guidance_scale = 5.0
|
||||
num_inference_steps = 25
|
||||
if args.is_turbo:
|
||||
guidance_scale = 0.0
|
||||
num_inference_steps = 4
|
||||
for prompt in VALIDATION_PROMPTS:
|
||||
with context:
|
||||
image = pipeline(prompt, num_inference_steps=25, generator=generator).images[0]
|
||||
image = pipeline(
|
||||
prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
|
||||
).images[0]
|
||||
images.append(image)
|
||||
|
||||
tracker_key = "test" if is_final_validation else "validation"
|
||||
@@ -141,7 +148,10 @@ def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_v
|
||||
if is_final_validation:
|
||||
pipeline.disable_lora()
|
||||
no_lora_images = [
|
||||
pipeline(prompt, num_inference_steps=25, generator=generator).images[0] for prompt in VALIDATION_PROMPTS
|
||||
pipeline(
|
||||
prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
|
||||
).images[0]
|
||||
for prompt in VALIDATION_PROMPTS
|
||||
]
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
@@ -423,6 +433,11 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--is_turbo",
|
||||
action="store_true",
|
||||
help=("Use if tuning SDXL Turbo instead of SDXL"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rank",
|
||||
type=int,
|
||||
@@ -444,6 +459,9 @@ def parse_args(input_args=None):
|
||||
if args.dataset_name is None:
|
||||
raise ValueError("Must provide a `dataset_name`.")
|
||||
|
||||
if args.is_turbo:
|
||||
assert "turbo" in args.pretrained_model_name_or_path
|
||||
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
@@ -560,6 +578,36 @@ def main(args):
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
|
||||
def enforce_zero_terminal_snr(scheduler):
|
||||
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L93
|
||||
# Original implementation https://arxiv.org/pdf/2305.08891.pdf
|
||||
# Turbo needs zero terminal SNR
|
||||
# Turbo: https://static1.squarespace.com/static/6213c340453c3f502425776e/t/65663480a92fba51d0e1023f/1701197769659/adversarial_diffusion_distillation.pdf
|
||||
# Convert betas to alphas_bar_sqrt
|
||||
alphas = 1 - scheduler.betas
|
||||
alphas_bar = alphas.cumprod(0)
|
||||
alphas_bar_sqrt = alphas_bar.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
# Shift so last timestep is zero.
|
||||
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
||||
# Scale so first timestep is back to old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
alphas_bar = alphas_bar_sqrt**2
|
||||
alphas = alphas_bar[1:] / alphas_bar[:-1]
|
||||
alphas = torch.cat([alphas_bar[0:1], alphas])
|
||||
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
scheduler.alphas_cumprod = alphas_cumprod
|
||||
return
|
||||
|
||||
if args.is_turbo:
|
||||
enforce_zero_terminal_snr(noise_scheduler)
|
||||
|
||||
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
||||
)
|
||||
@@ -909,6 +957,10 @@ def main(args):
|
||||
timesteps = torch.randint(
|
||||
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
|
||||
).repeat(2)
|
||||
if args.is_turbo:
|
||||
# Learn a 4 timestep schedule
|
||||
timesteps_0_to_3 = timesteps % 4
|
||||
timesteps = 250 * timesteps_0_to_3 + 249
|
||||
|
||||
# Add noise to the model input according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
|
||||
@@ -69,6 +69,7 @@ tags:
|
||||
- stable-diffusion-diffusers
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- diffusers-training
|
||||
- lora
|
||||
inference: true
|
||||
---
|
||||
|
||||
@@ -100,6 +100,8 @@ tags:
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- textual_inversion
|
||||
- diffusers-training
|
||||
- onxruntime
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
# PromptDiffusion Pipeline
|
||||
|
||||
From the project [page](https://zhendong-wang.github.io/prompt-diffusion.github.io/)
|
||||
|
||||
"With a prompt consisting of a task-specific example pair of images and text guidance, and a new query image, Prompt Diffusion can comprehend the desired task and generate the corresponding output image on both seen (trained) and unseen (new) task types."
|
||||
|
||||
For any usage questions, please refer to the [paper](https://arxiv.org/abs/2305.01115).
|
||||
|
||||
Prepare models by converting them from the [checkpoint](https://huggingface.co/zhendongw/prompt-diffusion)
|
||||
|
||||
To convert the controlnet, use cldm_v15.yaml from the [repository](https://github.com/Zhendong-Wang/Prompt-Diffusion/tree/main/models/):
|
||||
|
||||
```bash
|
||||
python convert_original_promptdiffusion_to_diffusers.py --checkpoint_path path-to-network-step04999.ckpt --original_config_file path-to-cldm_v15.yaml --dump_path path-to-output-directory
|
||||
```
|
||||
|
||||
To learn about how to convert the fine-tuned stable diffusion model, see the [Load different Stable Diffusion formats guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/other-formats).
|
||||
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import UniPCMultistepScheduler
|
||||
from diffusers.utils import load_image
|
||||
from promptdiffusioncontrolnet import PromptDiffusionControlNetModel
|
||||
from pipeline_prompt_diffusion import PromptDiffusionPipeline
|
||||
|
||||
|
||||
from PIL import ImageOps
|
||||
|
||||
image_a = ImageOps.invert(load_image("https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/house_line.png?raw=true"))
|
||||
|
||||
image_b = load_image("https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/house.png?raw=true")
|
||||
query = ImageOps.invert(load_image("https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/new_01.png?raw=true"))
|
||||
|
||||
# load prompt diffusion controlnet and prompt diffusion
|
||||
|
||||
controlnet = PromptDiffusionControlNetModel.from_pretrained("iczaw/prompt-diffusion-diffusers", subfolder="controlnet", torch_dtype=torch.float16)
|
||||
model_id = "path-to-model"
|
||||
pipe = PromptDiffusionPipeline.from_pretrained("iczaw/prompt-diffusion-diffusers", subfolder="base", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16")
|
||||
|
||||
# speed up diffusion process with faster scheduler and memory optimization
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
# remove following line if xformers is not installed
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
pipe.enable_model_cpu_offload()
|
||||
# generate image
|
||||
generator = torch.manual_seed(0)
|
||||
image = pipe("a tortoise", num_inference_steps=20, generator=generator, image_pair=[image_a,image_b], image=query).images[0]
|
||||
|
||||
```
|
||||
+2118
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,385 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.controlnet import (
|
||||
ControlNetConditioningEmbedding,
|
||||
ControlNetModel,
|
||||
ControlNetOutput,
|
||||
)
|
||||
from diffusers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class PromptDiffusionControlNetModel(ControlNetModel):
|
||||
"""
|
||||
A PromptDiffusionControlNet model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to 4):
|
||||
The number of channels in the input sample.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, defaults to 0):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
The number of layers per block.
|
||||
downsample_padding (`int`, defaults to 1):
|
||||
The padding to use for the downsampling convolution.
|
||||
mid_block_scale_factor (`float`, defaults to 1):
|
||||
The scale factor to use for the mid block.
|
||||
act_fn (`str`, defaults to "silu"):
|
||||
The activation function to use.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||
in post-processing.
|
||||
norm_eps (`float`, defaults to 1e-5):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
||||
`class_embed_type="projection"`.
|
||||
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
||||
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||
global_pool_conditions (`bool`, defaults to `False`):
|
||||
TODO(Patrick) - unused parameter.
|
||||
addition_embed_type_num_heads (`int`, defaults to 64):
|
||||
The number of heads to use for the `TextTimeEmbedding` layer.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 3,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
conditioning_channels,
|
||||
flip_sin_to_cos,
|
||||
freq_shift,
|
||||
down_block_types,
|
||||
mid_block_type,
|
||||
only_cross_attention,
|
||||
block_out_channels,
|
||||
layers_per_block,
|
||||
downsample_padding,
|
||||
mid_block_scale_factor,
|
||||
act_fn,
|
||||
norm_num_groups,
|
||||
norm_eps,
|
||||
cross_attention_dim,
|
||||
transformer_layers_per_block,
|
||||
encoder_hid_dim,
|
||||
encoder_hid_dim_type,
|
||||
attention_head_dim,
|
||||
num_attention_heads,
|
||||
use_linear_projection,
|
||||
class_embed_type,
|
||||
addition_embed_type,
|
||||
addition_time_embed_dim,
|
||||
num_class_embeds,
|
||||
upcast_attention,
|
||||
resnet_time_scale_shift,
|
||||
projection_class_embeddings_input_dim,
|
||||
controlnet_conditioning_channel_order,
|
||||
conditioning_embedding_out_channels,
|
||||
global_pool_conditions,
|
||||
addition_embed_type_num_heads,
|
||||
)
|
||||
self.controlnet_query_cond_embedding = ControlNetConditioningEmbedding(
|
||||
conditioning_embedding_channels=block_out_channels[0],
|
||||
block_out_channels=conditioning_embedding_out_channels,
|
||||
conditioning_channels=3,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.FloatTensor,
|
||||
controlnet_query_cond: torch.FloatTensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
|
||||
"""
|
||||
The [`~PromptDiffusionControlNetModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor.
|
||||
timestep (`Union[torch.Tensor, float, int]`):
|
||||
The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
The encoder hidden states.
|
||||
controlnet_cond (`torch.FloatTensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
controlnet_query_cond (`torch.FloatTensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
||||
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
||||
embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
added_cond_kwargs (`dict`):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||
guess_mode (`bool`, defaults to `False`):
|
||||
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
||||
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
# check channel order
|
||||
channel_order = self.config.controlnet_conditioning_channel_order
|
||||
|
||||
if channel_order == "rgb":
|
||||
# in rgb order by default
|
||||
...
|
||||
elif channel_order == "bgr":
|
||||
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
||||
else:
|
||||
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
aug_emb = None
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
if self.config.addition_embed_type is not None:
|
||||
if self.config.addition_embed_type == "text":
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
|
||||
elif self.config.addition_embed_type == "text_time":
|
||||
if "text_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||
if "time_ids" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
time_ids = added_cond_kwargs.get("time_ids")
|
||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||
|
||||
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
|
||||
emb = emb + aug_emb if aug_emb is not None else emb
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
||||
controlnet_query_cond = self.controlnet_query_cond_embedding(controlnet_query_cond)
|
||||
sample = sample + controlnet_cond + controlnet_query_cond
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 5. Control net blocks
|
||||
|
||||
controlnet_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
||||
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
||||
|
||||
down_block_res_samples = controlnet_down_block_res_samples
|
||||
|
||||
mid_block_res_sample = self.controlnet_mid_block(sample)
|
||||
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||
scales = scales * conditioning_scale
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||
else:
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
|
||||
if not return_dict:
|
||||
return (down_block_res_samples, mid_block_res_sample)
|
||||
|
||||
return ControlNetOutput(
|
||||
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
||||
)
|
||||
@@ -87,6 +87,7 @@ tags:
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- realfill
|
||||
- diffusers-training
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
|
||||
@@ -225,7 +225,14 @@ These are t2iadapter weights trained on {base_model} with new type of conditioni
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = ["stable-diffusion-xl", "stable-diffusion-xl-diffusers", "text-to-image", "diffusers", "t2iadapter"]
|
||||
tags = [
|
||||
"stable-diffusion-xl",
|
||||
"stable-diffusion-xl-diffusers",
|
||||
"text-to-image",
|
||||
"diffusers",
|
||||
"t2iadapter",
|
||||
"diffusers-training",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
|
||||
@@ -131,7 +131,7 @@ More information on all the CLI arguments and the environment are available on y
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers"]
|
||||
tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "diffusers-training"]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
|
||||
@@ -90,6 +90,7 @@ These are LoRA adaption weights for {base_model}. The weights were fine-tuned on
|
||||
"stable-diffusion-diffusers",
|
||||
"text-to-image",
|
||||
"diffusers",
|
||||
"diffusers-training",
|
||||
"lora",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
@@ -103,7 +103,14 @@ Special VAE used for training: {vae_path}.
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = ["stable-diffusion-xl", "stable-diffusion-xl-diffusers", "text-to-image", "diffusers", "lora"]
|
||||
tags = [
|
||||
"stable-diffusion-xl",
|
||||
"stable-diffusion-xl-diffusers",
|
||||
"text-to-image",
|
||||
"diffusers",
|
||||
"diffusers-training",
|
||||
"lora",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
|
||||
@@ -35,7 +35,7 @@ import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from datasets import load_dataset
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from packaging import version
|
||||
from torchvision import transforms
|
||||
@@ -101,6 +101,7 @@ Special VAE used for training: {vae_path}.
|
||||
"stable-diffusion-xl",
|
||||
"stable-diffusion-xl-diffusers",
|
||||
"text-to-image",
|
||||
"diffusers-training",
|
||||
"diffusers",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
@@ -895,14 +896,20 @@ def main(args):
|
||||
# fingerprint used by the cache for the other processes to load the result
|
||||
# details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
|
||||
new_fingerprint = Hasher.hash(args)
|
||||
new_fingerprint_for_vae = Hasher.hash("vae")
|
||||
train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
|
||||
train_dataset = train_dataset.map(
|
||||
new_fingerprint_for_vae = Hasher.hash(vae_path)
|
||||
train_dataset_with_embeddings = train_dataset.map(
|
||||
compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
|
||||
)
|
||||
train_dataset_with_vae = train_dataset.map(
|
||||
compute_vae_encodings_fn,
|
||||
batched=True,
|
||||
batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,
|
||||
new_fingerprint=new_fingerprint_for_vae,
|
||||
)
|
||||
precomputed_dataset = concatenate_datasets(
|
||||
[train_dataset_with_embeddings, train_dataset_with_vae.remove_columns(["image", "text"])], axis=1
|
||||
)
|
||||
precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)
|
||||
|
||||
del text_encoders, tokenizers, vae
|
||||
gc.collect()
|
||||
@@ -925,7 +932,7 @@ def main(args):
|
||||
|
||||
# DataLoaders creation:
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
precomputed_dataset,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
batch_size=args.train_batch_size,
|
||||
@@ -976,7 +983,7 @@ def main(args):
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num examples = {len(precomputed_dataset)}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
|
||||
@@ -105,7 +105,14 @@ These are textual inversion adaption weights for {base_model}. You can find some
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "textual_inversion"]
|
||||
tags = [
|
||||
"stable-diffusion",
|
||||
"stable-diffusion-diffusers",
|
||||
"text-to-image",
|
||||
"diffusers",
|
||||
"textual_inversion",
|
||||
"diffusers-training",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
|
||||
@@ -106,6 +106,7 @@ These are textual inversion adaption weights for {base_model}. You can find some
|
||||
"stable-diffusion-xl-diffusers",
|
||||
"text-to-image",
|
||||
"diffusers",
|
||||
"diffusers-training",
|
||||
"textual_inversion",
|
||||
]
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
wandb
|
||||
huggingface-cli
|
||||
bitsandbytes
|
||||
deepspeed
|
||||
peft>=0.6.0
|
||||
|
||||
@@ -81,6 +81,7 @@ tags:
|
||||
- wuerstchen
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- diffusers-training
|
||||
- lora
|
||||
inference: true
|
||||
---
|
||||
|
||||
@@ -82,6 +82,7 @@ tags:
|
||||
- wuerstchen
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- diffusers-training
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
|
||||
@@ -9,11 +9,11 @@ from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPip
|
||||
|
||||
ckpt_id = "PixArt-alpha/PixArt-alpha"
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/scripts/inference.py#L125
|
||||
interpolation_scale = {512: 1, 1024: 2}
|
||||
interpolation_scale = {256: 0.5, 512: 1, 1024: 2}
|
||||
|
||||
|
||||
def main(args):
|
||||
all_state_dict = torch.load(args.orig_ckpt_path)
|
||||
all_state_dict = torch.load(args.orig_ckpt_path, map_location="cpu")
|
||||
state_dict = all_state_dict.pop("state_dict")
|
||||
converted_state_dict = {}
|
||||
|
||||
@@ -22,7 +22,6 @@ def main(args):
|
||||
converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
|
||||
|
||||
# Caption projection.
|
||||
converted_state_dict["caption_projection.y_embedding"] = state_dict.pop("y_embedder.y_embedding")
|
||||
converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
|
||||
converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
|
||||
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
|
||||
@@ -155,6 +154,7 @@ def main(args):
|
||||
|
||||
assert transformer.pos_embed.pos_embed is not None
|
||||
state_dict.pop("pos_embed")
|
||||
state_dict.pop("y_embedder.y_embedding")
|
||||
assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
|
||||
|
||||
num_model_params = sum(p.numel() for p in transformer.parameters())
|
||||
@@ -187,7 +187,7 @@ if __name__ == "__main__":
|
||||
"--image_size",
|
||||
default=1024,
|
||||
type=int,
|
||||
choices=[512, 1024],
|
||||
choices=[256, 512, 1024],
|
||||
required=False,
|
||||
help="Image size of pretrained model, either 512 or 1024.",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
|
||||
import argparse
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
CLIPConfig,
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
DDPMWuerstchenScheduler,
|
||||
StableCascadeCombinedPipeline,
|
||||
StableCascadeDecoderPipeline,
|
||||
StableCascadePriorPipeline,
|
||||
)
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
|
||||
parser.add_argument("--model_path", type=str, default="../StableCascade", help="Location of Stable Cascade weights")
|
||||
parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file")
|
||||
parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file")
|
||||
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
|
||||
parser.add_argument("--save_org", type=str, default="diffusers", help="Hub organization to save the pipelines to")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
|
||||
|
||||
args = parser.parse_args()
|
||||
model_path = args.model_path
|
||||
|
||||
device = "cpu"
|
||||
|
||||
# set paths to model weights
|
||||
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
|
||||
decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}"
|
||||
|
||||
# Clip Text encoder and tokenizer
|
||||
config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
|
||||
config.text_config.projection_dim = config.projection_dim
|
||||
text_encoder = CLIPTextModelWithProjection.from_pretrained(
|
||||
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
|
||||
|
||||
# image processor
|
||||
feature_extractor = CLIPImageProcessor()
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
|
||||
|
||||
# Prior
|
||||
if args.use_safetensors:
|
||||
orig_state_dict = load_file(prior_checkpoint_path, device=device)
|
||||
else:
|
||||
orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
|
||||
|
||||
state_dict = {}
|
||||
for key in orig_state_dict.keys():
|
||||
if key.endswith("in_proj_weight"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
|
||||
elif key.endswith("in_proj_bias"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
|
||||
elif key.endswith("out_proj.weight"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
|
||||
elif key.endswith("out_proj.bias"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
|
||||
else:
|
||||
state_dict[key] = orig_state_dict[key]
|
||||
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
prior_model = StableCascadeUNet(
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
timestep_ratio_embedding_dim=64,
|
||||
patch_size=1,
|
||||
conditioning_dim=2048,
|
||||
block_out_channels=[2048, 2048],
|
||||
num_attention_heads=[32, 32],
|
||||
down_num_layers_per_block=[8, 24],
|
||||
up_num_layers_per_block=[24, 8],
|
||||
down_blocks_repeat_mappers=[1, 1],
|
||||
up_blocks_repeat_mappers=[1, 1],
|
||||
block_types_per_layer=[
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
clip_text_in_channels=1280,
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_image_in_channels=768,
|
||||
clip_seq=4,
|
||||
kernel_size=3,
|
||||
dropout=[0.1, 0.1],
|
||||
self_attn=True,
|
||||
timestep_conditioning_type=["sca", "crp"],
|
||||
switch_level=[False],
|
||||
)
|
||||
load_model_dict_into_meta(prior_model, state_dict)
|
||||
|
||||
# scheduler for prior and decoder
|
||||
scheduler = DDPMWuerstchenScheduler()
|
||||
|
||||
# Prior pipeline
|
||||
prior_pipeline = StableCascadePriorPipeline(
|
||||
prior=prior_model,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
scheduler=scheduler,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
prior_pipeline.save_pretrained(f"{args.save_org}/StableCascade-prior", push_to_hub=args.push_to_hub)
|
||||
|
||||
# Decoder
|
||||
if args.use_safetensors:
|
||||
orig_state_dict = load_file(decoder_checkpoint_path, device=device)
|
||||
else:
|
||||
orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
|
||||
|
||||
state_dict = {}
|
||||
for key in orig_state_dict.keys():
|
||||
if key.endswith("in_proj_weight"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
|
||||
elif key.endswith("in_proj_bias"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
|
||||
elif key.endswith("out_proj.weight"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
|
||||
elif key.endswith("out_proj.bias"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
|
||||
# rename clip_mapper to clip_txt_pooled_mapper
|
||||
elif key.endswith("clip_mapper.weight"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
|
||||
elif key.endswith("clip_mapper.bias"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
|
||||
else:
|
||||
state_dict[key] = orig_state_dict[key]
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
decoder = StableCascadeUNet(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
timestep_ratio_embedding_dim=64,
|
||||
patch_size=2,
|
||||
conditioning_dim=1280,
|
||||
block_out_channels=[320, 640, 1280, 1280],
|
||||
down_num_layers_per_block=[2, 6, 28, 6],
|
||||
up_num_layers_per_block=[6, 28, 6, 2],
|
||||
down_blocks_repeat_mappers=[1, 1, 1, 1],
|
||||
up_blocks_repeat_mappers=[3, 3, 2, 2],
|
||||
num_attention_heads=[0, 0, 20, 20],
|
||||
block_types_per_layer=[
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_seq=4,
|
||||
effnet_in_channels=16,
|
||||
pixel_mapper_in_channels=3,
|
||||
kernel_size=3,
|
||||
dropout=[0, 0, 0.1, 0.1],
|
||||
self_attn=True,
|
||||
timestep_conditioning_type=["sca"],
|
||||
)
|
||||
load_model_dict_into_meta(decoder, state_dict)
|
||||
|
||||
# VQGAN from Wuerstchen-V2
|
||||
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
|
||||
|
||||
# Decoder pipeline
|
||||
decoder_pipeline = StableCascadeDecoderPipeline(
|
||||
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
|
||||
)
|
||||
decoder_pipeline.save_pretrained(f"{args.save_org}/StableCascade-decoder", push_to_hub=args.push_to_hub)
|
||||
|
||||
# Stable Cascade combined pipeline
|
||||
stable_cascade_pipeline = StableCascadeCombinedPipeline(
|
||||
# Decoder
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
decoder=decoder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqmodel,
|
||||
# Prior
|
||||
prior_text_encoder=text_encoder,
|
||||
prior_tokenizer=tokenizer,
|
||||
prior_prior=prior_model,
|
||||
prior_scheduler=scheduler,
|
||||
prior_image_encoder=image_encoder,
|
||||
prior_feature_extractor=feature_extractor,
|
||||
)
|
||||
stable_cascade_pipeline.save_pretrained(f"{args.save_org}/StableCascade", push_to_hub=args.push_to_hub)
|
||||
@@ -86,6 +86,7 @@ else:
|
||||
"MotionAdapter",
|
||||
"MultiAdapter",
|
||||
"PriorTransformer",
|
||||
"StableCascadeUNet",
|
||||
"T2IAdapter",
|
||||
"T5FilmDecoder",
|
||||
"Transformer2DModel",
|
||||
@@ -160,6 +161,7 @@ else:
|
||||
"SASolverScheduler",
|
||||
"SchedulerMixin",
|
||||
"ScoreSdeVeScheduler",
|
||||
"TCDScheduler",
|
||||
"UnCLIPScheduler",
|
||||
"UniPCMultistepScheduler",
|
||||
"VQDiffusionScheduler",
|
||||
@@ -258,6 +260,9 @@ else:
|
||||
"SemanticStableDiffusionPipeline",
|
||||
"ShapEImg2ImgPipeline",
|
||||
"ShapEPipeline",
|
||||
"StableCascadeCombinedPipeline",
|
||||
"StableCascadeDecoderPipeline",
|
||||
"StableCascadePriorPipeline",
|
||||
"StableDiffusionAdapterPipeline",
|
||||
"StableDiffusionAttendAndExcitePipeline",
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
@@ -545,6 +550,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
SASolverScheduler,
|
||||
SchedulerMixin,
|
||||
ScoreSdeVeScheduler,
|
||||
TCDScheduler,
|
||||
UnCLIPScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
VQDiffusionScheduler,
|
||||
@@ -624,6 +630,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
SemanticStableDiffusionPipeline,
|
||||
ShapEImg2ImgPipeline,
|
||||
ShapEPipeline,
|
||||
StableCascadeCombinedPipeline,
|
||||
StableCascadeDecoderPipeline,
|
||||
StableCascadePriorPipeline,
|
||||
StableDiffusionAdapterPipeline,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
|
||||
@@ -127,7 +127,7 @@ class ConfigMixin:
|
||||
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
||||
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
|
||||
|
||||
Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite:
|
||||
This function is mostly copied from PyTorch's __getattr__ overwrite:
|
||||
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
||||
"""
|
||||
|
||||
@@ -259,6 +259,10 @@ class ConfigMixin:
|
||||
model = cls(**init_dict)
|
||||
|
||||
# make sure to also save config parameters that might be used for compatible classes
|
||||
# update _class_name
|
||||
if "_class_name" in hidden_dict:
|
||||
hidden_dict["_class_name"] = cls.__name__
|
||||
|
||||
model.register_to_config(**hidden_dict)
|
||||
|
||||
# add hidden kwargs of compatible classes to unused_kwargs
|
||||
@@ -529,7 +533,7 @@ class ConfigMixin:
|
||||
f"{cls.config_name} configuration file."
|
||||
)
|
||||
|
||||
# 5. Give nice info if config attributes are initiliazed to default because they have not been passed
|
||||
# 5. Give nice info if config attributes are initialized to default because they have not been passed
|
||||
passed_keys = set(init_dict.keys())
|
||||
if len(expected_keys - passed_keys) > 0:
|
||||
logger.info(
|
||||
|
||||
@@ -332,7 +332,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
||||
height: int,
|
||||
width: int,
|
||||
resize_mode: str = "default", # "defalt", "fill", "crop"
|
||||
resize_mode: str = "default", # "default", "fill", "crop"
|
||||
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
||||
"""
|
||||
Resize image.
|
||||
@@ -448,7 +448,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
image: PipelineImageInput,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
resize_mode: str = "default", # "defalt", "fill", "crop"
|
||||
resize_mode: str = "default", # "default", "fill", "crop"
|
||||
crops_coords: Optional[Tuple[int, int, int, int]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -479,7 +479,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
if isinstance(image, torch.Tensor):
|
||||
# if image is a pytorch tensor could have 2 possible shapes:
|
||||
# 1. batch x height x width: we should insert the channel dimension at position 1
|
||||
# 2. channnel x height x width: we should insert batch dimension at position 0,
|
||||
# 2. channel x height x width: we should insert batch dimension at position 0,
|
||||
# however, since both channel and batch dimension has same size 1, it is same to insert at position 1
|
||||
# for simplicity, we insert a dimension of size 1 at position 1 for both cases
|
||||
image = image.unsqueeze(1)
|
||||
|
||||
@@ -215,7 +215,7 @@ class IPAdapterMixin:
|
||||
else:
|
||||
logger.warning(
|
||||
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||
"Use `ip_adapter_image_embedding` to pass pre-geneated image embedding instead."
|
||||
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||
)
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
|
||||
@@ -63,13 +63,20 @@ def build_sub_model_components(
|
||||
num_in_channels=num_in_channels,
|
||||
image_size=image_size,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type=model_type,
|
||||
)
|
||||
return unet_components
|
||||
|
||||
if component_name == "vae":
|
||||
scaling_factor = kwargs.get("scaling_factor", None)
|
||||
vae_components = create_diffusers_vae_model_from_ldm(
|
||||
pipeline_class_name, original_config, checkpoint, image_size, scaling_factor, torch_dtype
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
image_size,
|
||||
scaling_factor,
|
||||
torch_dtype,
|
||||
model_type=model_type,
|
||||
)
|
||||
return vae_components
|
||||
|
||||
@@ -124,11 +131,12 @@ def build_sub_model_components(
|
||||
def set_additional_components(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
checkpoint=None,
|
||||
model_type=None,
|
||||
):
|
||||
components = {}
|
||||
if pipeline_class_name in REFINER_PIPELINES:
|
||||
model_type = infer_model_type(original_config, model_type=model_type)
|
||||
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
|
||||
is_refiner = model_type == "SDXL-Refiner"
|
||||
components.update(
|
||||
{
|
||||
@@ -181,6 +189,30 @@ class FromSingleFileMixin:
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
original_config_file (`str`, *optional*):
|
||||
The path to the original config file that was used to train the model. If not provided, the config file
|
||||
will be inferred from the checkpoint file.
|
||||
model_type (`str`, *optional*):
|
||||
The type of model to load. If not provided, the model type will be inferred from the checkpoint file.
|
||||
image_size (`int`, *optional*):
|
||||
The size of the image output. It's used to configure the `sample_size` parameter of the UNet and VAE model.
|
||||
load_safety_checker (`bool`, *optional*, defaults to `False`):
|
||||
Whether to load the safety checker model or not. By default, the safety checker is not loaded unless a `safety_checker` component is passed to the `kwargs`.
|
||||
num_in_channels (`int`, *optional*):
|
||||
Specify the number of input channels for the UNet model. Read more about how to configure UNet model with this parameter
|
||||
[here](https://huggingface.co/docs/diffusers/training/adapt_a_model#configure-unet2dconditionmodel-parameters).
|
||||
scaling_factor (`float`, *optional*):
|
||||
The scaling factor to use for the VAE model. If not provided, it is inferred from the config file first.
|
||||
If the scaling factor is not found in the config file, the default value 0.18215 is used.
|
||||
scheduler_type (`str`, *optional*):
|
||||
The type of scheduler to load. If not provided, the scheduler type will be inferred from the checkpoint file.
|
||||
prediction_type (`str`, *optional*):
|
||||
The type of prediction to load. If not provided, the prediction type will be inferred from the checkpoint file.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
|
||||
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
|
||||
below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
|
||||
@@ -28,6 +28,7 @@ from ..schedulers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EDMDPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
@@ -175,6 +176,7 @@ DIFFUSERS_TO_LDM_MAPPING = {
|
||||
|
||||
LDM_VAE_KEY = "first_stage_model."
|
||||
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
|
||||
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
|
||||
LDM_UNET_KEY = "model.diffusion_model."
|
||||
LDM_CONTROLNET_KEY = "control_model."
|
||||
LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
|
||||
@@ -305,7 +307,7 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=
|
||||
return original_config
|
||||
|
||||
|
||||
def infer_model_type(original_config, model_type=None):
|
||||
def infer_model_type(original_config, checkpoint=None, model_type=None):
|
||||
if model_type is not None:
|
||||
return model_type
|
||||
|
||||
@@ -323,7 +325,9 @@ def infer_model_type(original_config, model_type=None):
|
||||
|
||||
elif has_network_config:
|
||||
context_dim = original_config["model"]["params"]["network_config"]["params"]["context_dim"]
|
||||
if context_dim == 2048:
|
||||
if "edm_mean" in checkpoint and "edm_std" in checkpoint:
|
||||
model_type = "Playground"
|
||||
elif context_dim == 2048:
|
||||
model_type = "SDXL"
|
||||
else:
|
||||
model_type = "SDXL-Refiner"
|
||||
@@ -344,13 +348,13 @@ def set_image_size(pipeline_class_name, original_config, checkpoint, image_size=
|
||||
return image_size
|
||||
|
||||
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
|
||||
model_type = infer_model_type(original_config, model_type)
|
||||
model_type = infer_model_type(original_config, checkpoint, model_type)
|
||||
|
||||
if pipeline_class_name == "StableDiffusionUpscalePipeline":
|
||||
image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"]
|
||||
return image_size
|
||||
|
||||
elif model_type in ["SDXL", "SDXL-Refiner"]:
|
||||
elif model_type in ["SDXL", "SDXL-Refiner", "Playground"]:
|
||||
image_size = 1024
|
||||
return image_size
|
||||
|
||||
@@ -458,8 +462,8 @@ def create_unet_diffusers_config(original_config, image_size: int):
|
||||
config = {
|
||||
"sample_size": image_size // vae_scale_factor,
|
||||
"in_channels": unet_params["in_channels"],
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"down_block_types": down_block_types,
|
||||
"block_out_channels": block_out_channels,
|
||||
"layers_per_block": unet_params["num_res_blocks"],
|
||||
"cross_attention_dim": context_dim,
|
||||
"attention_head_dim": head_dim,
|
||||
@@ -478,7 +482,7 @@ def create_unet_diffusers_config(original_config, image_size: int):
|
||||
config["num_class_embeds"] = unet_params["num_classes"]
|
||||
|
||||
config["out_channels"] = unet_params["out_channels"]
|
||||
config["up_block_types"] = tuple(up_block_types)
|
||||
config["up_block_types"] = up_block_types
|
||||
|
||||
return config
|
||||
|
||||
@@ -506,12 +510,14 @@ def create_controlnet_diffusers_config(original_config, image_size: int):
|
||||
return controlnet_config
|
||||
|
||||
|
||||
def create_vae_diffusers_config(original_config, image_size, scaling_factor=None):
|
||||
def create_vae_diffusers_config(original_config, image_size, scaling_factor=None, latents_mean=None, latents_std=None):
|
||||
"""
|
||||
Creates a config for the diffusers based on the config of the LDM model.
|
||||
"""
|
||||
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
|
||||
if scaling_factor is None and "scale_factor" in original_config["model"]["params"]:
|
||||
if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None):
|
||||
scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR
|
||||
elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]):
|
||||
scaling_factor = original_config["model"]["params"]["scale_factor"]
|
||||
elif scaling_factor is None:
|
||||
scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR
|
||||
@@ -524,13 +530,15 @@ def create_vae_diffusers_config(original_config, image_size, scaling_factor=None
|
||||
"sample_size": image_size,
|
||||
"in_channels": vae_params["in_channels"],
|
||||
"out_channels": vae_params["out_ch"],
|
||||
"down_block_types": tuple(down_block_types),
|
||||
"up_block_types": tuple(up_block_types),
|
||||
"block_out_channels": tuple(block_out_channels),
|
||||
"down_block_types": down_block_types,
|
||||
"up_block_types": up_block_types,
|
||||
"block_out_channels": block_out_channels,
|
||||
"latent_channels": vae_params["z_channels"],
|
||||
"layers_per_block": vae_params["num_res_blocks"],
|
||||
"scaling_factor": scaling_factor,
|
||||
}
|
||||
if latents_mean is not None and latents_std is not None:
|
||||
config.update({"latents_mean": latents_mean, "latents_std": latents_std})
|
||||
|
||||
return config
|
||||
|
||||
@@ -1172,6 +1180,7 @@ def create_diffusers_unet_model_from_ldm(
|
||||
extract_ema=False,
|
||||
image_size=None,
|
||||
torch_dtype=None,
|
||||
model_type=None,
|
||||
):
|
||||
from ..models import UNet2DConditionModel
|
||||
|
||||
@@ -1190,7 +1199,9 @@ def create_diffusers_unet_model_from_ldm(
|
||||
else:
|
||||
num_in_channels = 4
|
||||
|
||||
image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
|
||||
image_size = set_image_size(
|
||||
pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
|
||||
)
|
||||
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
unet_config["in_channels"] = num_in_channels
|
||||
unet_config["upcast_attention"] = upcast_attention
|
||||
@@ -1223,14 +1234,40 @@ def create_diffusers_unet_model_from_ldm(
|
||||
|
||||
|
||||
def create_diffusers_vae_model_from_ldm(
|
||||
pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=None, torch_dtype=None
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
image_size=None,
|
||||
scaling_factor=None,
|
||||
torch_dtype=None,
|
||||
model_type=None,
|
||||
):
|
||||
# import here to avoid circular imports
|
||||
from ..models import AutoencoderKL
|
||||
|
||||
image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
|
||||
image_size = set_image_size(
|
||||
pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
|
||||
)
|
||||
model_type = infer_model_type(original_config, checkpoint, model_type)
|
||||
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size, scaling_factor=scaling_factor)
|
||||
if model_type == "Playground":
|
||||
edm_mean = (
|
||||
checkpoint["edm_mean"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_mean"].tolist()
|
||||
)
|
||||
edm_std = (
|
||||
checkpoint["edm_std"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_std"].tolist()
|
||||
)
|
||||
else:
|
||||
edm_mean = None
|
||||
edm_std = None
|
||||
|
||||
vae_config = create_vae_diffusers_config(
|
||||
original_config,
|
||||
image_size=image_size,
|
||||
scaling_factor=scaling_factor,
|
||||
latents_mean=edm_mean,
|
||||
latents_std=edm_std,
|
||||
)
|
||||
diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
@@ -1265,7 +1302,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
|
||||
local_files_only=False,
|
||||
torch_dtype=None,
|
||||
):
|
||||
model_type = infer_model_type(original_config, model_type=model_type)
|
||||
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
|
||||
|
||||
if model_type == "FrozenOpenCLIPEmbedder":
|
||||
config_name = "stabilityai/stable-diffusion-2"
|
||||
@@ -1332,7 +1369,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
|
||||
"text_encoder_2": text_encoder_2,
|
||||
}
|
||||
|
||||
elif model_type == "SDXL":
|
||||
elif model_type in ["SDXL", "Playground"]:
|
||||
try:
|
||||
config_name = "openai/clip-vit-large-patch14"
|
||||
tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
|
||||
@@ -1383,7 +1420,7 @@ def create_scheduler_from_ldm(
|
||||
model_type=None,
|
||||
):
|
||||
scheduler_config = get_default_scheduler_config()
|
||||
model_type = infer_model_type(original_config, model_type=model_type)
|
||||
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
|
||||
|
||||
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
|
||||
|
||||
@@ -1406,7 +1443,8 @@ def create_scheduler_from_ldm(
|
||||
|
||||
if model_type in ["SDXL", "SDXL-Refiner"]:
|
||||
scheduler_type = "euler"
|
||||
|
||||
elif model_type == "Playground":
|
||||
scheduler_type = "edm_dpm_solver_multistep"
|
||||
else:
|
||||
beta_start = original_config["model"]["params"].get("linear_start", 0.02)
|
||||
beta_end = original_config["model"]["params"].get("linear_end", 0.085)
|
||||
@@ -1438,6 +1476,26 @@ def create_scheduler_from_ldm(
|
||||
elif scheduler_type == "ddim":
|
||||
scheduler = DDIMScheduler.from_config(scheduler_config)
|
||||
|
||||
elif scheduler_type == "edm_dpm_solver_multistep":
|
||||
scheduler_config = {
|
||||
"algorithm_type": "dpmsolver++",
|
||||
"dynamic_thresholding_ratio": 0.995,
|
||||
"euler_at_final": False,
|
||||
"final_sigmas_type": "zero",
|
||||
"lower_order_final": True,
|
||||
"num_train_timesteps": 1000,
|
||||
"prediction_type": "epsilon",
|
||||
"rho": 7.0,
|
||||
"sample_max_value": 1.0,
|
||||
"sigma_data": 0.5,
|
||||
"sigma_max": 80.0,
|
||||
"sigma_min": 0.002,
|
||||
"solver_order": 2,
|
||||
"solver_type": "midpoint",
|
||||
"thresholding": False,
|
||||
}
|
||||
scheduler = EDMDPMSolverMultistepScheduler(**scheduler_config)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@ if is_torch_available():
|
||||
_import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
|
||||
_import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
|
||||
_import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
|
||||
_import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
|
||||
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
|
||||
_import_structure["vq_model"] = ["VQModel"]
|
||||
|
||||
@@ -80,6 +81,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
I2VGenXLUNet,
|
||||
Kandinsky3UNet,
|
||||
MotionAdapter,
|
||||
StableCascadeUNet,
|
||||
UNet1DModel,
|
||||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
|
||||
@@ -143,7 +143,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'layer_norm_i2vgen'
|
||||
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = False,
|
||||
attention_type: str = "default",
|
||||
@@ -440,7 +440,6 @@ class TemporalBasicTransformerBlock(nn.Module):
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
self.norm_in = nn.LayerNorm(dim)
|
||||
self.ff_in = FeedForward(
|
||||
dim,
|
||||
dim_out=time_mix_inner_dim,
|
||||
|
||||
@@ -124,9 +124,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
||||
) from e
|
||||
except (UnicodeDecodeError, ValueError):
|
||||
raise OSError(
|
||||
f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
|
||||
f"at '{checkpoint_file}'. "
|
||||
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
|
||||
f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -92,13 +92,24 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "layer_norm",
|
||||
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-5,
|
||||
attention_type: str = "default",
|
||||
caption_channels: int = None,
|
||||
interpolation_scale: float = None,
|
||||
):
|
||||
super().__init__()
|
||||
if patch_size is not None:
|
||||
if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
|
||||
raise NotImplementedError(
|
||||
f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
|
||||
)
|
||||
elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
|
||||
)
|
||||
|
||||
self.use_linear_projection = use_linear_projection
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
@@ -168,8 +179,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
self.width = sample_size
|
||||
|
||||
self.patch_size = patch_size
|
||||
interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
|
||||
interpolation_scale = max(interpolation_scale, 1)
|
||||
interpolation_scale = (
|
||||
interpolation_scale if interpolation_scale is not None else max(self.config.sample_size // 64, 1)
|
||||
)
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
|
||||
@@ -10,6 +10,7 @@ if is_torch_available():
|
||||
from .unet_kandinsky3 import Kandinsky3UNet
|
||||
from .unet_motion_model import MotionAdapter, UNetMotionModel
|
||||
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
|
||||
from .unet_stable_cascade import StableCascadeUNet
|
||||
from .uvit_2d import UVit2DModel
|
||||
|
||||
|
||||
|
||||
@@ -99,9 +99,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
time_embedding_type: str = "positional",
|
||||
freq_shift: int = 0,
|
||||
flip_sin_to_cos: bool = True,
|
||||
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
|
||||
block_out_channels: Tuple[int] = (224, 448, 672, 896),
|
||||
down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
|
||||
block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
|
||||
layers_per_block: int = 2,
|
||||
mid_block_scale_factor: float = 1,
|
||||
downsample_padding: int = 1,
|
||||
|
||||
@@ -75,6 +75,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
The tuple of downsample blocks to use.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
|
||||
The tuple of upsample blocks to use.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
||||
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer is skipped.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
@@ -107,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
"DownBlock2D",
|
||||
)
|
||||
up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn"
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
|
||||
layers_per_block: int = 2
|
||||
@@ -252,16 +255,21 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
self.down_blocks = down_blocks
|
||||
|
||||
# mid
|
||||
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
dropout=self.dropout,
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
||||
split_head_dim=self.split_head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
if self.config.mid_block_type == "UNetMidBlock2DCrossAttn":
|
||||
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
dropout=self.dropout,
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
||||
split_head_dim=self.split_head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
elif self.config.mid_block_type is None:
|
||||
self.mid_block = None
|
||||
else:
|
||||
raise ValueError(f"Unexpected mid_block_type {self.config.mid_block_type}")
|
||||
|
||||
# up
|
||||
up_blocks = []
|
||||
@@ -412,7 +420,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
down_block_res_samples = new_down_block_res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
|
||||
if self.mid_block is not None:
|
||||
sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
|
||||
|
||||
if mid_block_additional_residual is not None:
|
||||
sample += mid_block_additional_residual
|
||||
|
||||
@@ -90,7 +90,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
||||
num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
|
||||
num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
|
||||
num_frames: int = 25,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -0,0 +1,609 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ..attention_processor import Attention
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.wuerstchen.modeling_wuerstchen_common.WuerstchenLayerNorm with WuerstchenLayerNorm -> SDCascadeLayerNorm
|
||||
class SDCascadeLayerNorm(nn.LayerNorm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
x = super().forward(x)
|
||||
return x.permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
class SDCascadeTimestepBlock(nn.Module):
|
||||
def __init__(self, c, c_timestep, conds=[]):
|
||||
super().__init__()
|
||||
linear_cls = nn.Linear
|
||||
self.mapper = linear_cls(c_timestep, c * 2)
|
||||
self.conds = conds
|
||||
for cname in conds:
|
||||
setattr(self, f"mapper_{cname}", linear_cls(c_timestep, c * 2))
|
||||
|
||||
def forward(self, x, t):
|
||||
t = t.chunk(len(self.conds) + 1, dim=1)
|
||||
a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
|
||||
for i, c in enumerate(self.conds):
|
||||
ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
|
||||
a, b = a + ac, b + bc
|
||||
return x * (1 + a) + b
|
||||
|
||||
|
||||
class SDCascadeResBlock(nn.Module):
|
||||
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
|
||||
super().__init__()
|
||||
self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
|
||||
self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.channelwise = nn.Sequential(
|
||||
nn.Linear(c + c_skip, c * 4),
|
||||
nn.GELU(),
|
||||
GlobalResponseNorm(c * 4),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(c * 4, c),
|
||||
)
|
||||
|
||||
def forward(self, x, x_skip=None):
|
||||
x_res = x
|
||||
x = self.norm(self.depthwise(x))
|
||||
if x_skip is not None:
|
||||
x = torch.cat([x, x_skip], dim=1)
|
||||
x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
return x + x_res
|
||||
|
||||
|
||||
# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
|
||||
class GlobalResponseNorm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
||||
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
||||
|
||||
def forward(self, x):
|
||||
agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
||||
stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6)
|
||||
return self.gamma * (x * stand_div_norm) + self.beta + x
|
||||
|
||||
|
||||
class SDCascadeAttnBlock(nn.Module):
|
||||
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
|
||||
super().__init__()
|
||||
linear_cls = nn.Linear
|
||||
|
||||
self.self_attn = self_attn
|
||||
self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
|
||||
self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
|
||||
|
||||
def forward(self, x, kv):
|
||||
kv = self.kv_mapper(kv)
|
||||
norm_x = self.norm(x)
|
||||
if self.self_attn:
|
||||
batch_size, channel, _, _ = x.shape
|
||||
kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1)
|
||||
x = x + self.attention(norm_x, encoder_hidden_states=kv)
|
||||
return x
|
||||
|
||||
|
||||
class UpDownBlock2d(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, mode, enabled=True):
|
||||
super().__init__()
|
||||
if mode not in ["up", "down"]:
|
||||
raise ValueError(f"{mode} not supported")
|
||||
interpolation = (
|
||||
nn.Upsample(scale_factor=2 if mode == "up" else 0.5, mode="bilinear", align_corners=True)
|
||||
if enabled
|
||||
else nn.Identity()
|
||||
)
|
||||
mapping = nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
||||
self.blocks = nn.ModuleList([interpolation, mapping] if mode == "up" else [mapping, interpolation])
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class StableCascadeUNetOutput(BaseOutput):
|
||||
sample: torch.FloatTensor = None
|
||||
|
||||
|
||||
class StableCascadeUNet(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
timestep_ratio_embedding_dim: int = 64,
|
||||
patch_size: int = 1,
|
||||
conditioning_dim: int = 2048,
|
||||
block_out_channels: Tuple[int] = (2048, 2048),
|
||||
num_attention_heads: Tuple[int] = (32, 32),
|
||||
down_num_layers_per_block: Tuple[int] = (8, 24),
|
||||
up_num_layers_per_block: Tuple[int] = (24, 8),
|
||||
down_blocks_repeat_mappers: Optional[Tuple[int]] = (
|
||||
1,
|
||||
1,
|
||||
),
|
||||
up_blocks_repeat_mappers: Optional[Tuple[int]] = (1, 1),
|
||||
block_types_per_layer: Tuple[Tuple[str]] = (
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
|
||||
),
|
||||
clip_text_in_channels: Optional[int] = None,
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_image_in_channels: Optional[int] = None,
|
||||
clip_seq=4,
|
||||
effnet_in_channels: Optional[int] = None,
|
||||
pixel_mapper_in_channels: Optional[int] = None,
|
||||
kernel_size=3,
|
||||
dropout: Union[float, Tuple[float]] = (0.1, 0.1),
|
||||
self_attn: Union[bool, Tuple[bool]] = True,
|
||||
timestep_conditioning_type: Tuple[str] = ("sca", "crp"),
|
||||
switch_level: Optional[Tuple[bool]] = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters:
|
||||
in_channels (`int`, defaults to 16):
|
||||
Number of channels in the input sample.
|
||||
out_channels (`int`, defaults to 16):
|
||||
Number of channels in the output sample.
|
||||
timestep_ratio_embedding_dim (`int`, defaults to 64):
|
||||
Dimension of the projected time embedding.
|
||||
patch_size (`int`, defaults to 1):
|
||||
Patch size to use for pixel unshuffling layer
|
||||
conditioning_dim (`int`, defaults to 2048):
|
||||
Dimension of the image and text conditional embedding.
|
||||
block_out_channels (Tuple[int], defaults to (2048, 2048)):
|
||||
Tuple of output channels for each block.
|
||||
num_attention_heads (Tuple[int], defaults to (32, 32)):
|
||||
Number of attention heads in each attention block. Set to -1 to if block types in a layer do not have attention.
|
||||
down_num_layers_per_block (Tuple[int], defaults to [8, 24]):
|
||||
Number of layers in each down block.
|
||||
up_num_layers_per_block (Tuple[int], defaults to [24, 8]):
|
||||
Number of layers in each up block.
|
||||
down_blocks_repeat_mappers (Tuple[int], optional, defaults to [1, 1]):
|
||||
Number of 1x1 Convolutional layers to repeat in each down block.
|
||||
up_blocks_repeat_mappers (Tuple[int], optional, defaults to [1, 1]):
|
||||
Number of 1x1 Convolutional layers to repeat in each up block.
|
||||
block_types_per_layer (Tuple[Tuple[str]], optional,
|
||||
defaults to (
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock")
|
||||
):
|
||||
Block types used in each layer of the up/down blocks.
|
||||
clip_text_in_channels (`int`, *optional*, defaults to `None`):
|
||||
Number of input channels for CLIP based text conditioning.
|
||||
clip_text_pooled_in_channels (`int`, *optional*, defaults to 1280):
|
||||
Number of input channels for pooled CLIP text embeddings.
|
||||
clip_image_in_channels (`int`, *optional*):
|
||||
Number of input channels for CLIP based image conditioning.
|
||||
clip_seq (`int`, *optional*, defaults to 4):
|
||||
effnet_in_channels (`int`, *optional*, defaults to `None`):
|
||||
Number of input channels for effnet conditioning.
|
||||
pixel_mapper_in_channels (`int`, defaults to `None`):
|
||||
Number of input channels for pixel mapper conditioning.
|
||||
kernel_size (`int`, *optional*, defaults to 3):
|
||||
Kernel size to use in the block convolutional layers.
|
||||
dropout (Tuple[float], *optional*, defaults to (0.1, 0.1)):
|
||||
Dropout to use per block.
|
||||
self_attn (Union[bool, Tuple[bool]]):
|
||||
Tuple of booleans that determine whether to use self attention in a block or not.
|
||||
timestep_conditioning_type (Tuple[str], defaults to ("sca", "crp")):
|
||||
Timestep conditioning type.
|
||||
switch_level (Optional[Tuple[bool]], *optional*, defaults to `None`):
|
||||
Tuple that indicates whether upsampling or downsampling should be applied in a block
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
if len(block_out_channels) != len(down_num_layers_per_block):
|
||||
raise ValueError(
|
||||
f"Number of elements in `down_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}"
|
||||
)
|
||||
|
||||
elif len(block_out_channels) != len(up_num_layers_per_block):
|
||||
raise ValueError(
|
||||
f"Number of elements in `up_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}"
|
||||
)
|
||||
|
||||
elif len(block_out_channels) != len(down_blocks_repeat_mappers):
|
||||
raise ValueError(
|
||||
f"Number of elements in `down_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}"
|
||||
)
|
||||
|
||||
elif len(block_out_channels) != len(up_blocks_repeat_mappers):
|
||||
raise ValueError(
|
||||
f"Number of elements in `up_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}"
|
||||
)
|
||||
|
||||
elif len(block_out_channels) != len(block_types_per_layer):
|
||||
raise ValueError(
|
||||
f"Number of elements in `block_types_per_layer` must match the length of `block_out_channels`: {len(block_out_channels)}"
|
||||
)
|
||||
|
||||
if isinstance(dropout, float):
|
||||
dropout = (dropout,) * len(block_out_channels)
|
||||
if isinstance(self_attn, bool):
|
||||
self_attn = (self_attn,) * len(block_out_channels)
|
||||
|
||||
# CONDITIONING
|
||||
if effnet_in_channels is not None:
|
||||
self.effnet_mapper = nn.Sequential(
|
||||
nn.Conv2d(effnet_in_channels, block_out_channels[0] * 4, kernel_size=1),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(block_out_channels[0] * 4, block_out_channels[0], kernel_size=1),
|
||||
SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6),
|
||||
)
|
||||
if pixel_mapper_in_channels is not None:
|
||||
self.pixels_mapper = nn.Sequential(
|
||||
nn.Conv2d(pixel_mapper_in_channels, block_out_channels[0] * 4, kernel_size=1),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(block_out_channels[0] * 4, block_out_channels[0], kernel_size=1),
|
||||
SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6),
|
||||
)
|
||||
|
||||
self.clip_txt_pooled_mapper = nn.Linear(clip_text_pooled_in_channels, conditioning_dim * clip_seq)
|
||||
if clip_text_in_channels is not None:
|
||||
self.clip_txt_mapper = nn.Linear(clip_text_in_channels, conditioning_dim)
|
||||
if clip_image_in_channels is not None:
|
||||
self.clip_img_mapper = nn.Linear(clip_image_in_channels, conditioning_dim * clip_seq)
|
||||
self.clip_norm = nn.LayerNorm(conditioning_dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.embedding = nn.Sequential(
|
||||
nn.PixelUnshuffle(patch_size),
|
||||
nn.Conv2d(in_channels * (patch_size**2), block_out_channels[0], kernel_size=1),
|
||||
SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6),
|
||||
)
|
||||
|
||||
def get_block(block_type, in_channels, nhead, c_skip=0, dropout=0, self_attn=True):
|
||||
if block_type == "SDCascadeResBlock":
|
||||
return SDCascadeResBlock(in_channels, c_skip, kernel_size=kernel_size, dropout=dropout)
|
||||
elif block_type == "SDCascadeAttnBlock":
|
||||
return SDCascadeAttnBlock(in_channels, conditioning_dim, nhead, self_attn=self_attn, dropout=dropout)
|
||||
elif block_type == "SDCascadeTimestepBlock":
|
||||
return SDCascadeTimestepBlock(
|
||||
in_channels, timestep_ratio_embedding_dim, conds=timestep_conditioning_type
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Block type {block_type} not supported")
|
||||
|
||||
# BLOCKS
|
||||
# -- down blocks
|
||||
self.down_blocks = nn.ModuleList()
|
||||
self.down_downscalers = nn.ModuleList()
|
||||
self.down_repeat_mappers = nn.ModuleList()
|
||||
for i in range(len(block_out_channels)):
|
||||
if i > 0:
|
||||
self.down_downscalers.append(
|
||||
nn.Sequential(
|
||||
SDCascadeLayerNorm(block_out_channels[i - 1], elementwise_affine=False, eps=1e-6),
|
||||
UpDownBlock2d(
|
||||
block_out_channels[i - 1], block_out_channels[i], mode="down", enabled=switch_level[i - 1]
|
||||
)
|
||||
if switch_level is not None
|
||||
else nn.Conv2d(block_out_channels[i - 1], block_out_channels[i], kernel_size=2, stride=2),
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.down_downscalers.append(nn.Identity())
|
||||
|
||||
down_block = nn.ModuleList()
|
||||
for _ in range(down_num_layers_per_block[i]):
|
||||
for block_type in block_types_per_layer[i]:
|
||||
block = get_block(
|
||||
block_type,
|
||||
block_out_channels[i],
|
||||
num_attention_heads[i],
|
||||
dropout=dropout[i],
|
||||
self_attn=self_attn[i],
|
||||
)
|
||||
down_block.append(block)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
if down_blocks_repeat_mappers is not None:
|
||||
block_repeat_mappers = nn.ModuleList()
|
||||
for _ in range(down_blocks_repeat_mappers[i] - 1):
|
||||
block_repeat_mappers.append(nn.Conv2d(block_out_channels[i], block_out_channels[i], kernel_size=1))
|
||||
self.down_repeat_mappers.append(block_repeat_mappers)
|
||||
|
||||
# -- up blocks
|
||||
self.up_blocks = nn.ModuleList()
|
||||
self.up_upscalers = nn.ModuleList()
|
||||
self.up_repeat_mappers = nn.ModuleList()
|
||||
for i in reversed(range(len(block_out_channels))):
|
||||
if i > 0:
|
||||
self.up_upscalers.append(
|
||||
nn.Sequential(
|
||||
SDCascadeLayerNorm(block_out_channels[i], elementwise_affine=False, eps=1e-6),
|
||||
UpDownBlock2d(
|
||||
block_out_channels[i], block_out_channels[i - 1], mode="up", enabled=switch_level[i - 1]
|
||||
)
|
||||
if switch_level is not None
|
||||
else nn.ConvTranspose2d(
|
||||
block_out_channels[i], block_out_channels[i - 1], kernel_size=2, stride=2
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.up_upscalers.append(nn.Identity())
|
||||
|
||||
up_block = nn.ModuleList()
|
||||
for j in range(up_num_layers_per_block[::-1][i]):
|
||||
for k, block_type in enumerate(block_types_per_layer[i]):
|
||||
c_skip = block_out_channels[i] if i < len(block_out_channels) - 1 and j == k == 0 else 0
|
||||
block = get_block(
|
||||
block_type,
|
||||
block_out_channels[i],
|
||||
num_attention_heads[i],
|
||||
c_skip=c_skip,
|
||||
dropout=dropout[i],
|
||||
self_attn=self_attn[i],
|
||||
)
|
||||
up_block.append(block)
|
||||
self.up_blocks.append(up_block)
|
||||
|
||||
if up_blocks_repeat_mappers is not None:
|
||||
block_repeat_mappers = nn.ModuleList()
|
||||
for _ in range(up_blocks_repeat_mappers[::-1][i] - 1):
|
||||
block_repeat_mappers.append(nn.Conv2d(block_out_channels[i], block_out_channels[i], kernel_size=1))
|
||||
self.up_repeat_mappers.append(block_repeat_mappers)
|
||||
|
||||
# OUTPUT
|
||||
self.clf = nn.Sequential(
|
||||
SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6),
|
||||
nn.Conv2d(block_out_channels[0], out_channels * (patch_size**2), kernel_size=1),
|
||||
nn.PixelShuffle(patch_size),
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
torch.nn.init.xavier_uniform_(m.weight)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02)
|
||||
nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) if hasattr(self, "clip_txt_mapper") else None
|
||||
nn.init.normal_(self.clip_img_mapper.weight, std=0.02) if hasattr(self, "clip_img_mapper") else None
|
||||
|
||||
if hasattr(self, "effnet_mapper"):
|
||||
nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
|
||||
nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
|
||||
|
||||
if hasattr(self, "pixels_mapper"):
|
||||
nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
|
||||
nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
|
||||
|
||||
torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
|
||||
nn.init.constant_(self.clf[1].weight, 0) # outputs
|
||||
|
||||
# blocks
|
||||
for level_block in self.down_blocks + self.up_blocks:
|
||||
for block in level_block:
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks[0]))
|
||||
elif isinstance(block, SDCascadeTimestepBlock):
|
||||
nn.init.constant_(block.mapper.weight, 0)
|
||||
|
||||
def get_timestep_ratio_embedding(self, timestep_ratio, max_positions=10000):
|
||||
r = timestep_ratio * max_positions
|
||||
half_dim = self.config.timestep_ratio_embedding_dim // 2
|
||||
|
||||
emb = math.log(max_positions) / (half_dim - 1)
|
||||
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
|
||||
emb = r[:, None] * emb[None, :]
|
||||
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
|
||||
|
||||
if self.config.timestep_ratio_embedding_dim % 2 == 1: # zero pad
|
||||
emb = nn.functional.pad(emb, (0, 1), mode="constant")
|
||||
|
||||
return emb.to(dtype=r.dtype)
|
||||
|
||||
def get_clip_embeddings(self, clip_txt_pooled, clip_txt=None, clip_img=None):
|
||||
if len(clip_txt_pooled.shape) == 2:
|
||||
clip_txt_pool = clip_txt_pooled.unsqueeze(1)
|
||||
clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(
|
||||
clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.config.clip_seq, -1
|
||||
)
|
||||
if clip_txt is not None and clip_img is not None:
|
||||
clip_txt = self.clip_txt_mapper(clip_txt)
|
||||
if len(clip_img.shape) == 2:
|
||||
clip_img = clip_img.unsqueeze(1)
|
||||
clip_img = self.clip_img_mapper(clip_img).view(
|
||||
clip_img.size(0), clip_img.size(1) * self.config.clip_seq, -1
|
||||
)
|
||||
clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
|
||||
else:
|
||||
clip = clip_txt_pool
|
||||
return self.clip_norm(clip)
|
||||
|
||||
def _down_encode(self, x, r_embed, clip):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
x = downscaler(x)
|
||||
for i in range(len(repmap) + 1):
|
||||
for block in down_block:
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
|
||||
elif isinstance(block, SDCascadeAttnBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, clip, use_reentrant=False
|
||||
)
|
||||
elif isinstance(block, SDCascadeTimestepBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, r_embed, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
x = x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), use_reentrant=False
|
||||
)
|
||||
if i < len(repmap):
|
||||
x = repmap[i](x)
|
||||
level_outputs.insert(0, x)
|
||||
else:
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
x = downscaler(x)
|
||||
for i in range(len(repmap) + 1):
|
||||
for block in down_block:
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
x = block(x)
|
||||
elif isinstance(block, SDCascadeAttnBlock):
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, SDCascadeTimestepBlock):
|
||||
x = block(x, r_embed)
|
||||
else:
|
||||
x = block(x)
|
||||
if i < len(repmap):
|
||||
x = repmap[i](x)
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
for j in range(len(repmap) + 1):
|
||||
for k, block in enumerate(up_block):
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
skip = level_outputs[i] if k == 0 and i > 0 else None
|
||||
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
|
||||
x = torch.nn.functional.interpolate(
|
||||
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, skip, use_reentrant=False
|
||||
)
|
||||
elif isinstance(block, SDCascadeAttnBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, clip, use_reentrant=False
|
||||
)
|
||||
elif isinstance(block, SDCascadeTimestepBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, r_embed, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
|
||||
if j < len(repmap):
|
||||
x = repmap[j](x)
|
||||
x = upscaler(x)
|
||||
else:
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
for j in range(len(repmap) + 1):
|
||||
for k, block in enumerate(up_block):
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
skip = level_outputs[i] if k == 0 and i > 0 else None
|
||||
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
|
||||
x = torch.nn.functional.interpolate(
|
||||
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
|
||||
)
|
||||
x = block(x, skip)
|
||||
elif isinstance(block, SDCascadeAttnBlock):
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, SDCascadeTimestepBlock):
|
||||
x = block(x, r_embed)
|
||||
else:
|
||||
x = block(x)
|
||||
if j < len(repmap):
|
||||
x = repmap[j](x)
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample,
|
||||
timestep_ratio,
|
||||
clip_text_pooled,
|
||||
clip_text=None,
|
||||
clip_img=None,
|
||||
effnet=None,
|
||||
pixels=None,
|
||||
sca=None,
|
||||
crp=None,
|
||||
return_dict=True,
|
||||
):
|
||||
if pixels is None:
|
||||
pixels = sample.new_zeros(sample.size(0), 3, 8, 8)
|
||||
|
||||
# Process the conditioning embeddings
|
||||
timestep_ratio_embed = self.get_timestep_ratio_embedding(timestep_ratio)
|
||||
for c in self.config.timestep_conditioning_type:
|
||||
if c == "sca":
|
||||
cond = sca
|
||||
elif c == "crp":
|
||||
cond = crp
|
||||
else:
|
||||
cond = None
|
||||
t_cond = cond or torch.zeros_like(timestep_ratio)
|
||||
timestep_ratio_embed = torch.cat([timestep_ratio_embed, self.get_timestep_ratio_embedding(t_cond)], dim=1)
|
||||
clip = self.get_clip_embeddings(clip_txt_pooled=clip_text_pooled, clip_txt=clip_text, clip_img=clip_img)
|
||||
|
||||
# Model Blocks
|
||||
x = self.embedding(sample)
|
||||
if hasattr(self, "effnet_mapper") and effnet is not None:
|
||||
x = x + self.effnet_mapper(
|
||||
nn.functional.interpolate(effnet, size=x.shape[-2:], mode="bilinear", align_corners=True)
|
||||
)
|
||||
if hasattr(self, "pixels_mapper"):
|
||||
x = x + nn.functional.interpolate(
|
||||
self.pixels_mapper(pixels), size=x.shape[-2:], mode="bilinear", align_corners=True
|
||||
)
|
||||
level_outputs = self._down_encode(x, timestep_ratio_embed, clip)
|
||||
x = self._up_decode(level_outputs, timestep_ratio_embed, clip)
|
||||
sample = self.clf(x)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
return StableCascadeUNetOutput(sample=sample)
|
||||
@@ -11,6 +11,7 @@ from ..utils import (
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_torch_available,
|
||||
is_torch_npu_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
@@ -176,6 +177,11 @@ else:
|
||||
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline"]
|
||||
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
|
||||
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
|
||||
_import_structure["stable_cascade"] = [
|
||||
"StableCascadeCombinedPipeline",
|
||||
"StableCascadeDecoderPipeline",
|
||||
"StableCascadePriorPipeline",
|
||||
]
|
||||
_import_structure["stable_diffusion"].extend(
|
||||
[
|
||||
"CLIPImageProjection",
|
||||
@@ -424,6 +430,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pixart_alpha import PixArtAlphaPipeline
|
||||
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
from .stable_cascade import (
|
||||
StableCascadeCombinedPipeline,
|
||||
StableCascadeDecoderPipeline,
|
||||
StableCascadePriorPipeline,
|
||||
)
|
||||
from .stable_diffusion import (
|
||||
CLIPImageProjection,
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
|
||||
@@ -81,7 +81,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -400,15 +400,22 @@ class AnimateDiffPipeline(
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
else:
|
||||
repeat_dims = [1]
|
||||
image_embeds = []
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
if do_classifier_free_guidance:
|
||||
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
||||
)
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
else:
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
@@ -509,9 +516,9 @@ class AnimateDiffPipeline(
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
||||
)
|
||||
elif ip_adapter_image_embeds[0].ndim != 3:
|
||||
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
|
||||
@@ -661,8 +668,8 @@ class AnimateDiffPipeline(
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
|
||||
[`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
|
||||
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
|
||||
"""
|
||||
|
||||
@@ -783,6 +790,8 @@ class AnimateDiffPipeline(
|
||||
|
||||
self._num_timesteps = len(timesteps)
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
# 8. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
@@ -822,13 +831,14 @@ class AnimateDiffPipeline(
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 9. Post processing
|
||||
if output_type == "latent":
|
||||
return AnimateDiffPipelineOutput(frames=latents)
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
|
||||
# 9. Offload all models
|
||||
# 10. Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -100,7 +100,7 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -478,15 +478,22 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
else:
|
||||
repeat_dims = [1]
|
||||
image_embeds = []
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
if do_classifier_free_guidance:
|
||||
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
||||
)
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
else:
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
@@ -589,9 +596,9 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
||||
)
|
||||
elif ip_adapter_image_embeds[0].ndim != 3:
|
||||
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
)
|
||||
|
||||
def get_timesteps(self, num_inference_steps, timesteps, strength, device):
|
||||
@@ -821,8 +828,8 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`AnimateDiffPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
|
||||
[`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
|
||||
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
|
||||
"""
|
||||
|
||||
@@ -935,6 +942,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
|
||||
self._num_timesteps = len(timesteps)
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
# 8. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
@@ -973,15 +981,11 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if output_type == "latent":
|
||||
return AnimateDiffPipelineOutput(frames=latents)
|
||||
|
||||
# 9. Post-processing
|
||||
video_tensor = self.decode_latents(latents)
|
||||
|
||||
if output_type == "pt":
|
||||
video = video_tensor
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
|
||||
# 10. Offload all models
|
||||
|
||||
@@ -343,7 +343,7 @@ class AutoPipelineForText2Image(ConfigMixin):
|
||||
pipeline linked to the pipeline class using pattern matching on pipeline class name.
|
||||
|
||||
All the modules the pipeline contains will be used to initialize the new pipeline without reallocating
|
||||
additional memoery.
|
||||
additional memory.
|
||||
|
||||
The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
@@ -616,7 +616,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
|
||||
image-to-image pipeline linked to the pipeline class using pattern matching on pipeline class name.
|
||||
|
||||
All the modules the pipeline contains will be used to initialize the new pipeline without reallocating
|
||||
additional memoery.
|
||||
additional memory.
|
||||
|
||||
The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
@@ -892,7 +892,7 @@ class AutoPipelineForInpainting(ConfigMixin):
|
||||
pipeline linked to the pipeline class using pattern matching on pipeline class name.
|
||||
|
||||
All the modules the pipeline class contain will be used to initialize the new pipeline without reallocating
|
||||
additional memoery.
|
||||
additional memory.
|
||||
|
||||
The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
|
||||
@@ -510,15 +510,22 @@ class StableDiffusionControlNetPipeline(
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
else:
|
||||
repeat_dims = [1]
|
||||
image_embeds = []
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
if do_classifier_free_guidance:
|
||||
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
||||
)
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
else:
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
@@ -726,9 +733,9 @@ class StableDiffusionControlNetPipeline(
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
||||
)
|
||||
elif ip_adapter_image_embeds[0].ndim != 3:
|
||||
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
)
|
||||
|
||||
def check_image(self, image, prompt, prompt_embeds):
|
||||
|
||||
@@ -503,15 +503,22 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
else:
|
||||
repeat_dims = [1]
|
||||
image_embeds = []
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
if do_classifier_free_guidance:
|
||||
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
||||
)
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
else:
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
@@ -713,9 +720,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
||||
)
|
||||
elif ip_adapter_image_embeds[0].ndim != 3:
|
||||
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
|
||||
|
||||
@@ -628,15 +628,22 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
else:
|
||||
repeat_dims = [1]
|
||||
image_embeds = []
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
if do_classifier_free_guidance:
|
||||
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
||||
)
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
else:
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
@@ -871,9 +878,9 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
||||
)
|
||||
elif ip_adapter_image_embeds[0].ndim != 3:
|
||||
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
|
||||
|
||||
@@ -537,15 +537,22 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
else:
|
||||
repeat_dims = [1]
|
||||
image_embeds = []
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
if do_classifier_free_guidance:
|
||||
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
||||
)
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
else:
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
@@ -817,9 +824,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
||||
)
|
||||
elif ip_adapter_image_embeds[0].ndim != 3:
|
||||
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
)
|
||||
|
||||
def prepare_control_image(
|
||||
|
||||
@@ -515,15 +515,22 @@ class StableDiffusionXLControlNetPipeline(
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
else:
|
||||
repeat_dims = [1]
|
||||
image_embeds = []
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
if do_classifier_free_guidance:
|
||||
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
||||
)
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
else:
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
@@ -730,9 +737,9 @@ class StableDiffusionXLControlNetPipeline(
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
||||
)
|
||||
elif ip_adapter_image_embeds[0].ndim != 3:
|
||||
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
|
||||
@@ -909,6 +916,10 @@ class StableDiffusionXLControlNetPipeline(
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
@property
|
||||
def denoising_end(self):
|
||||
return self._denoising_end
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
@@ -923,6 +934,7 @@ class StableDiffusionXLControlNetPipeline(
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
denoising_end: Optional[float] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
@@ -982,6 +994,13 @@ class StableDiffusionXLControlNetPipeline(
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
denoising_end (`float`, *optional*):
|
||||
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
||||
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
||||
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
||||
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
||||
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
||||
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
A higher guidance scale value encourages the model to generate images closely linked to the text
|
||||
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
@@ -1144,6 +1163,7 @@ class StableDiffusionXLControlNetPipeline(
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._denoising_end = denoising_end
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1318,6 +1338,23 @@ class StableDiffusionXLControlNetPipeline(
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
# 8.1 Apply denoising_end
|
||||
if (
|
||||
self.denoising_end is not None
|
||||
and isinstance(self.denoising_end, float)
|
||||
and self.denoising_end > 0
|
||||
and self.denoising_end < 1
|
||||
):
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
|
||||
)
|
||||
)
|
||||
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
||||
timesteps = timesteps[:num_inference_steps]
|
||||
|
||||
is_unet_compiled = is_compiled_module(self.unet)
|
||||
is_controlnet_compiled = is_compiled_module(self.controlnet)
|
||||
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
|
||||
@@ -1423,7 +1460,22 @@ class StableDiffusionXLControlNetPipeline(
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
||||
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
||||
if has_latents_mean and has_latents_std:
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
||||
else:
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
|
||||
@@ -567,15 +567,22 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
else:
|
||||
repeat_dims = [1]
|
||||
image_embeds = []
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
if do_classifier_free_guidance:
|
||||
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
||||
)
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
else:
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
@@ -794,9 +801,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
||||
)
|
||||
elif ip_adapter_image_embeds[0].ndim != 3:
|
||||
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
|
||||
@@ -1580,7 +1587,22 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
||||
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
||||
if has_latents_mean and has_latents_std:
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
||||
else:
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
|
||||
@@ -83,7 +83,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -726,13 +726,14 @@ class I2VGenXLPipeline(
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
# 8. Post processing
|
||||
if output_type == "latent":
|
||||
return I2VGenXLPipelineOutput(frames=latents)
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
|
||||
video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
# 9. Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
|
||||
+12
-5
@@ -453,15 +453,22 @@ class LatentConsistencyModelImg2ImgPipeline(
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
else:
|
||||
repeat_dims = [1]
|
||||
image_embeds = []
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
if do_classifier_free_guidance:
|
||||
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
||||
)
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
else:
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
@@ -647,9 +654,9 @@ class LatentConsistencyModelImg2ImgPipeline(
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
||||
)
|
||||
elif ip_adapter_image_embeds[0].ndim != 3:
|
||||
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
+12
-5
@@ -437,15 +437,22 @@ class LatentConsistencyModelPipeline(
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
else:
|
||||
repeat_dims = [1]
|
||||
image_embeds = []
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
if do_classifier_free_guidance:
|
||||
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
||||
)
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
else:
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
@@ -579,9 +586,9 @@ class LatentConsistencyModelPipeline(
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
||||
)
|
||||
elif ip_adapter_image_embeds[0].ndim != 3:
|
||||
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -107,7 +107,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -582,9 +582,9 @@ class PIAPipeline(
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
||||
)
|
||||
elif ip_adapter_image_embeds[0].ndim != 3:
|
||||
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
@@ -619,15 +619,22 @@ class PIAPipeline(
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
else:
|
||||
repeat_dims = [1]
|
||||
image_embeds = []
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
if do_classifier_free_guidance:
|
||||
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
||||
)
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
else:
|
||||
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
)
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
@@ -853,8 +860,8 @@ class PIAPipeline(
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
|
||||
[`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] is
|
||||
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
@@ -1011,13 +1018,14 @@ class PIAPipeline(
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
# 9. Post processing
|
||||
if output_type == "latent":
|
||||
return PIAPipelineOutput(frames=latents)
|
||||
video = latents
|
||||
else:
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
|
||||
# 9. Offload all models
|
||||
# 10. Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -0,0 +1,508 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import (
|
||||
model_info,
|
||||
)
|
||||
from packaging import version
|
||||
|
||||
from ..utils import (
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
get_class_from_dynamic_module,
|
||||
is_peft_available,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
)
|
||||
from ..utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
import transformers
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
|
||||
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
|
||||
|
||||
|
||||
INDEX_FILE = "diffusion_pytorch_model.bin"
|
||||
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
|
||||
DUMMY_MODULES_FOLDER = "diffusers.utils"
|
||||
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
|
||||
CONNECTED_PIPES_KEYS = ["prior"]
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
LOADABLE_CLASSES = {
|
||||
"diffusers": {
|
||||
"ModelMixin": ["save_pretrained", "from_pretrained"],
|
||||
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
|
||||
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
|
||||
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
"transformers": {
|
||||
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
|
||||
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
|
||||
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
|
||||
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
|
||||
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
|
||||
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
"onnxruntime.training": {
|
||||
"ORTModule": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
}
|
||||
|
||||
ALL_IMPORTABLE_CLASSES = {}
|
||||
for library in LOADABLE_CLASSES:
|
||||
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
||||
|
||||
|
||||
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
|
||||
"""
|
||||
Checking for safetensors compatibility:
|
||||
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
|
||||
files to know which safetensors files are needed.
|
||||
- The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
|
||||
|
||||
Converting default pytorch serialized filenames to safetensors serialized filenames:
|
||||
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
|
||||
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
|
||||
extension is replaced with ".safetensors"
|
||||
"""
|
||||
pt_filenames = []
|
||||
|
||||
sf_filenames = set()
|
||||
|
||||
passed_components = passed_components or []
|
||||
|
||||
for filename in filenames:
|
||||
_, extension = os.path.splitext(filename)
|
||||
|
||||
if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
|
||||
continue
|
||||
|
||||
if extension == ".bin":
|
||||
pt_filenames.append(os.path.normpath(filename))
|
||||
elif extension == ".safetensors":
|
||||
sf_filenames.add(os.path.normpath(filename))
|
||||
|
||||
for filename in pt_filenames:
|
||||
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
|
||||
path, filename = os.path.split(filename)
|
||||
filename, extension = os.path.splitext(filename)
|
||||
|
||||
if filename.startswith("pytorch_model"):
|
||||
filename = filename.replace("pytorch_model", "model")
|
||||
else:
|
||||
filename = filename
|
||||
|
||||
expected_sf_filename = os.path.normpath(os.path.join(path, filename))
|
||||
expected_sf_filename = f"{expected_sf_filename}.safetensors"
|
||||
if expected_sf_filename not in sf_filenames:
|
||||
logger.warning(f"{expected_sf_filename} not found")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
|
||||
weight_names = [
|
||||
WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME,
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
|
||||
# model_pytorch, diffusion_model_pytorch, ...
|
||||
weight_prefixes = [w.split(".")[0] for w in weight_names]
|
||||
# .bin, .safetensors, ...
|
||||
weight_suffixs = [w.split(".")[-1] for w in weight_names]
|
||||
# -00001-of-00002
|
||||
transformers_index_format = r"\d{5}-of-\d{5}"
|
||||
|
||||
if variant is not None:
|
||||
# `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
|
||||
variant_file_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
|
||||
)
|
||||
# `text_encoder/pytorch_model.bin.index.fp16.json`
|
||||
variant_index_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
||||
)
|
||||
|
||||
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
|
||||
non_variant_file_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
|
||||
)
|
||||
# `text_encoder/pytorch_model.bin.index.json`
|
||||
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
|
||||
|
||||
if variant is not None:
|
||||
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
|
||||
variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
|
||||
variant_filenames = variant_weights | variant_indexes
|
||||
else:
|
||||
variant_filenames = set()
|
||||
|
||||
non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
|
||||
non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
|
||||
non_variant_filenames = non_variant_weights | non_variant_indexes
|
||||
|
||||
# all variant filenames will be used by default
|
||||
usable_filenames = set(variant_filenames)
|
||||
|
||||
def convert_to_variant(filename):
|
||||
if "index" in filename:
|
||||
variant_filename = filename.replace("index", f"index.{variant}")
|
||||
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
|
||||
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
|
||||
else:
|
||||
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
|
||||
return variant_filename
|
||||
|
||||
for f in non_variant_filenames:
|
||||
variant_filename = convert_to_variant(f)
|
||||
if variant_filename not in usable_filenames:
|
||||
usable_filenames.add(f)
|
||||
|
||||
return usable_filenames, variant_filenames
|
||||
|
||||
|
||||
@validate_hf_hub_args
|
||||
def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
|
||||
info = model_info(
|
||||
pretrained_model_name_or_path,
|
||||
token=token,
|
||||
revision=None,
|
||||
)
|
||||
filenames = {sibling.rfilename for sibling in info.siblings}
|
||||
comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
|
||||
comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]
|
||||
|
||||
if set(model_filenames).issubset(set(comp_model_filenames)):
|
||||
warnings.warn(
|
||||
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
|
||||
def _unwrap_model(model):
|
||||
"""Unwraps a model."""
|
||||
if is_compiled_module(model):
|
||||
model = model._orig_mod
|
||||
|
||||
if is_peft_available():
|
||||
from peft import PeftModel
|
||||
|
||||
if isinstance(model, PeftModel):
|
||||
model = model.base_model.model
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def maybe_raise_or_warn(
|
||||
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
|
||||
):
|
||||
"""Simple helper method to raise or warn in case incorrect module has been passed"""
|
||||
if not is_pipeline_module:
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
expected_class_obj = None
|
||||
for class_name, class_candidate in class_candidates.items():
|
||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
||||
expected_class_obj = class_candidate
|
||||
|
||||
# Dynamo wraps the original model in a private class.
|
||||
# I didn't find a public API to get the original class.
|
||||
sub_model = passed_class_obj[name]
|
||||
unwrapped_sub_model = _unwrap_model(sub_model)
|
||||
model_cls = unwrapped_sub_model.__class__
|
||||
|
||||
if not issubclass(model_cls, expected_class_obj):
|
||||
raise ValueError(
|
||||
f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
|
||||
" has the correct type"
|
||||
)
|
||||
|
||||
|
||||
def get_class_obj_and_candidates(
|
||||
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
|
||||
):
|
||||
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
|
||||
component_folder = os.path.join(cache_dir, component_name)
|
||||
|
||||
if is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
||||
elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
|
||||
# load custom component
|
||||
class_obj = get_class_from_dynamic_module(
|
||||
component_folder, module_file=library_name + ".py", class_name=class_name
|
||||
)
|
||||
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
||||
else:
|
||||
# else we just import it from the library.
|
||||
library = importlib.import_module(library_name)
|
||||
|
||||
class_obj = getattr(library, class_name)
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
return class_obj, class_candidates
|
||||
|
||||
|
||||
def _get_pipeline_class(
|
||||
class_obj,
|
||||
config=None,
|
||||
load_connected_pipeline=False,
|
||||
custom_pipeline=None,
|
||||
repo_id=None,
|
||||
hub_revision=None,
|
||||
class_name=None,
|
||||
cache_dir=None,
|
||||
revision=None,
|
||||
):
|
||||
if custom_pipeline is not None:
|
||||
if custom_pipeline.endswith(".py"):
|
||||
path = Path(custom_pipeline)
|
||||
# decompose into folder & file
|
||||
file_name = path.name
|
||||
custom_pipeline = path.parent.absolute()
|
||||
elif repo_id is not None:
|
||||
file_name = f"{custom_pipeline}.py"
|
||||
custom_pipeline = repo_id
|
||||
else:
|
||||
file_name = CUSTOM_PIPELINE_FILE_NAME
|
||||
|
||||
if repo_id is not None and hub_revision is not None:
|
||||
# if we load the pipeline code from the Hub
|
||||
# make sure to overwrite the `revision`
|
||||
revision = hub_revision
|
||||
|
||||
return get_class_from_dynamic_module(
|
||||
custom_pipeline,
|
||||
module_file=file_name,
|
||||
class_name=class_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
if class_obj.__name__ != "DiffusionPipeline":
|
||||
return class_obj
|
||||
|
||||
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
|
||||
class_name = class_name or config["_class_name"]
|
||||
if not class_name:
|
||||
raise ValueError(
|
||||
"The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`."
|
||||
)
|
||||
|
||||
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
|
||||
|
||||
pipeline_cls = getattr(diffusers_module, class_name)
|
||||
|
||||
if load_connected_pipeline:
|
||||
from .auto_pipeline import _get_connected_pipeline
|
||||
|
||||
connected_pipeline_cls = _get_connected_pipeline(pipeline_cls)
|
||||
if connected_pipeline_cls is not None:
|
||||
logger.info(
|
||||
f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`"
|
||||
)
|
||||
else:
|
||||
logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.")
|
||||
|
||||
pipeline_cls = connected_pipeline_cls or pipeline_cls
|
||||
|
||||
return pipeline_cls
|
||||
|
||||
|
||||
def load_sub_model(
|
||||
library_name: str,
|
||||
class_name: str,
|
||||
importable_classes: List[Any],
|
||||
pipelines: Any,
|
||||
is_pipeline_module: bool,
|
||||
pipeline_class: Any,
|
||||
torch_dtype: torch.dtype,
|
||||
provider: Any,
|
||||
sess_options: Any,
|
||||
device_map: Optional[Union[Dict[str, torch.device], str]],
|
||||
max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
|
||||
offload_folder: Optional[Union[str, os.PathLike]],
|
||||
offload_state_dict: bool,
|
||||
model_variants: Dict[str, str],
|
||||
name: str,
|
||||
from_flax: bool,
|
||||
variant: str,
|
||||
low_cpu_mem_usage: bool,
|
||||
cached_folder: Union[str, os.PathLike],
|
||||
):
|
||||
"""Helper method to load the module `name` from `library_name` and `class_name`"""
|
||||
# retrieve class candidates
|
||||
class_obj, class_candidates = get_class_obj_and_candidates(
|
||||
library_name,
|
||||
class_name,
|
||||
importable_classes,
|
||||
pipelines,
|
||||
is_pipeline_module,
|
||||
component_name=name,
|
||||
cache_dir=cached_folder,
|
||||
)
|
||||
|
||||
load_method_name = None
|
||||
# retrieve load method name
|
||||
for class_name, class_candidate in class_candidates.items():
|
||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
||||
load_method_name = importable_classes[class_name][1]
|
||||
|
||||
# if load method name is None, then we have a dummy module -> raise Error
|
||||
if load_method_name is None:
|
||||
none_module = class_obj.__module__
|
||||
is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
|
||||
TRANSFORMERS_DUMMY_MODULES_FOLDER
|
||||
)
|
||||
if is_dummy_path and "dummy" in none_module:
|
||||
# call class_obj for nice error message of missing requirements
|
||||
class_obj()
|
||||
|
||||
raise ValueError(
|
||||
f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
|
||||
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
|
||||
)
|
||||
|
||||
load_method = getattr(class_obj, load_method_name)
|
||||
|
||||
# add kwargs to loading method
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
loading_kwargs = {}
|
||||
if issubclass(class_obj, torch.nn.Module):
|
||||
loading_kwargs["torch_dtype"] = torch_dtype
|
||||
if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
|
||||
loading_kwargs["provider"] = provider
|
||||
loading_kwargs["sess_options"] = sess_options
|
||||
|
||||
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
|
||||
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
else:
|
||||
transformers_version = "N/A"
|
||||
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedModel)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
|
||||
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
|
||||
# This makes sure that the weights won't be initialized which significantly speeds up loading.
|
||||
if is_diffusers_model or is_transformers_model:
|
||||
loading_kwargs["device_map"] = device_map
|
||||
loading_kwargs["max_memory"] = max_memory
|
||||
loading_kwargs["offload_folder"] = offload_folder
|
||||
loading_kwargs["offload_state_dict"] = offload_state_dict
|
||||
loading_kwargs["variant"] = model_variants.pop(name, None)
|
||||
|
||||
if from_flax:
|
||||
loading_kwargs["from_flax"] = True
|
||||
|
||||
# the following can be deleted once the minimum required `transformers` version
|
||||
# is higher than 4.27
|
||||
if (
|
||||
is_transformers_model
|
||||
and loading_kwargs["variant"] is not None
|
||||
and transformers_version < version.parse("4.27.0")
|
||||
):
|
||||
raise ImportError(
|
||||
f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
|
||||
)
|
||||
elif is_transformers_model and loading_kwargs["variant"] is None:
|
||||
loading_kwargs.pop("variant")
|
||||
|
||||
# if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
|
||||
if not (from_flax and is_transformers_model):
|
||||
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
else:
|
||||
loading_kwargs["low_cpu_mem_usage"] = False
|
||||
|
||||
# check if the module is in a subdirectory
|
||||
if os.path.isdir(os.path.join(cached_folder, name)):
|
||||
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
|
||||
else:
|
||||
# else load from the root directory
|
||||
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
|
||||
|
||||
return loaded_sub_model
|
||||
|
||||
|
||||
def _fetch_class_library_tuple(module):
|
||||
# import it here to avoid circular import
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
pipelines = getattr(diffusers_module, "pipelines")
|
||||
|
||||
# register the config from the original module, not the dynamo compiled one
|
||||
not_compiled_module = _unwrap_model(module)
|
||||
library = not_compiled_module.__module__.split(".")[0]
|
||||
|
||||
# check if the module is a pipeline module
|
||||
module_path_items = not_compiled_module.__module__.split(".")
|
||||
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
|
||||
|
||||
path = not_compiled_module.__module__.split(".")
|
||||
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
|
||||
|
||||
# if library is not in LOADABLE_CLASSES, then it is a custom module.
|
||||
# Or if it's a pipeline module, then the module is inside the pipeline
|
||||
# folder so we set the library to module name.
|
||||
if is_pipeline_module:
|
||||
library = pipeline_dir
|
||||
elif library not in LOADABLE_CLASSES:
|
||||
library = not_compiled_module.__module__
|
||||
|
||||
# retrieve class_name
|
||||
class_name = not_compiled_module.__class__.__name__
|
||||
|
||||
return (library, class_name)
|
||||
@@ -19,7 +19,6 @@ import inspect
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
@@ -49,16 +48,13 @@ from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||
from ..utils import (
|
||||
CONFIG_NAME,
|
||||
DEPRECATED_REVISION_ARGS,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
BaseOutput,
|
||||
PushToHubMixin,
|
||||
deprecate,
|
||||
get_class_from_dynamic_module,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_peft_available,
|
||||
is_torch_npu_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
numpy_to_pil,
|
||||
)
|
||||
@@ -66,55 +62,37 @@ from ..utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from ..utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
import transformers
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
|
||||
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
||||
if is_torch_npu_available():
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, PushToHubMixin
|
||||
|
||||
from .pipeline_loading_utils import (
|
||||
ALL_IMPORTABLE_CLASSES,
|
||||
CONNECTED_PIPES_KEYS,
|
||||
CUSTOM_PIPELINE_FILE_NAME,
|
||||
LOADABLE_CLASSES,
|
||||
_fetch_class_library_tuple,
|
||||
_get_pipeline_class,
|
||||
_unwrap_model,
|
||||
is_safetensors_compatible,
|
||||
load_sub_model,
|
||||
maybe_raise_or_warn,
|
||||
variant_compatible_siblings,
|
||||
warn_deprecated_model_variant,
|
||||
)
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
|
||||
|
||||
INDEX_FILE = "diffusion_pytorch_model.bin"
|
||||
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
|
||||
DUMMY_MODULES_FOLDER = "diffusers.utils"
|
||||
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
|
||||
CONNECTED_PIPES_KEYS = ["prior"]
|
||||
|
||||
LIBRARIES = []
|
||||
for library in LOADABLE_CLASSES:
|
||||
LIBRARIES.append(library)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
LOADABLE_CLASSES = {
|
||||
"diffusers": {
|
||||
"ModelMixin": ["save_pretrained", "from_pretrained"],
|
||||
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
|
||||
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
|
||||
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
"transformers": {
|
||||
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
|
||||
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
|
||||
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
|
||||
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
|
||||
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
|
||||
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
"onnxruntime.training": {
|
||||
"ORTModule": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
}
|
||||
|
||||
ALL_IMPORTABLE_CLASSES = {}
|
||||
for library in LOADABLE_CLASSES:
|
||||
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImagePipelineOutput(BaseOutput):
|
||||
"""
|
||||
@@ -142,432 +120,6 @@ class AudioPipelineOutput(BaseOutput):
|
||||
audios: np.ndarray
|
||||
|
||||
|
||||
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
|
||||
"""
|
||||
Checking for safetensors compatibility:
|
||||
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
|
||||
files to know which safetensors files are needed.
|
||||
- The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
|
||||
|
||||
Converting default pytorch serialized filenames to safetensors serialized filenames:
|
||||
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
|
||||
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
|
||||
extension is replaced with ".safetensors"
|
||||
"""
|
||||
pt_filenames = []
|
||||
|
||||
sf_filenames = set()
|
||||
|
||||
passed_components = passed_components or []
|
||||
|
||||
for filename in filenames:
|
||||
_, extension = os.path.splitext(filename)
|
||||
|
||||
if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
|
||||
continue
|
||||
|
||||
if extension == ".bin":
|
||||
pt_filenames.append(os.path.normpath(filename))
|
||||
elif extension == ".safetensors":
|
||||
sf_filenames.add(os.path.normpath(filename))
|
||||
|
||||
for filename in pt_filenames:
|
||||
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
|
||||
path, filename = os.path.split(filename)
|
||||
filename, extension = os.path.splitext(filename)
|
||||
|
||||
if filename.startswith("pytorch_model"):
|
||||
filename = filename.replace("pytorch_model", "model")
|
||||
else:
|
||||
filename = filename
|
||||
|
||||
expected_sf_filename = os.path.normpath(os.path.join(path, filename))
|
||||
expected_sf_filename = f"{expected_sf_filename}.safetensors"
|
||||
if expected_sf_filename not in sf_filenames:
|
||||
logger.warning(f"{expected_sf_filename} not found")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
|
||||
weight_names = [
|
||||
WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME,
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
|
||||
# model_pytorch, diffusion_model_pytorch, ...
|
||||
weight_prefixes = [w.split(".")[0] for w in weight_names]
|
||||
# .bin, .safetensors, ...
|
||||
weight_suffixs = [w.split(".")[-1] for w in weight_names]
|
||||
# -00001-of-00002
|
||||
transformers_index_format = r"\d{5}-of-\d{5}"
|
||||
|
||||
if variant is not None:
|
||||
# `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
|
||||
variant_file_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
|
||||
)
|
||||
# `text_encoder/pytorch_model.bin.index.fp16.json`
|
||||
variant_index_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
||||
)
|
||||
|
||||
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
|
||||
non_variant_file_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
|
||||
)
|
||||
# `text_encoder/pytorch_model.bin.index.json`
|
||||
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
|
||||
|
||||
if variant is not None:
|
||||
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
|
||||
variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
|
||||
variant_filenames = variant_weights | variant_indexes
|
||||
else:
|
||||
variant_filenames = set()
|
||||
|
||||
non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
|
||||
non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
|
||||
non_variant_filenames = non_variant_weights | non_variant_indexes
|
||||
|
||||
# all variant filenames will be used by default
|
||||
usable_filenames = set(variant_filenames)
|
||||
|
||||
def convert_to_variant(filename):
|
||||
if "index" in filename:
|
||||
variant_filename = filename.replace("index", f"index.{variant}")
|
||||
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
|
||||
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
|
||||
else:
|
||||
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
|
||||
return variant_filename
|
||||
|
||||
for f in non_variant_filenames:
|
||||
variant_filename = convert_to_variant(f)
|
||||
if variant_filename not in usable_filenames:
|
||||
usable_filenames.add(f)
|
||||
|
||||
return usable_filenames, variant_filenames
|
||||
|
||||
|
||||
@validate_hf_hub_args
|
||||
def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
|
||||
info = model_info(
|
||||
pretrained_model_name_or_path,
|
||||
token=token,
|
||||
revision=None,
|
||||
)
|
||||
filenames = {sibling.rfilename for sibling in info.siblings}
|
||||
comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
|
||||
comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]
|
||||
|
||||
if set(model_filenames).issubset(set(comp_model_filenames)):
|
||||
warnings.warn(
|
||||
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
|
||||
def _unwrap_model(model):
|
||||
"""Unwraps a model."""
|
||||
if is_compiled_module(model):
|
||||
model = model._orig_mod
|
||||
|
||||
if is_peft_available():
|
||||
from peft import PeftModel
|
||||
|
||||
if isinstance(model, PeftModel):
|
||||
model = model.base_model.model
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def maybe_raise_or_warn(
|
||||
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
|
||||
):
|
||||
"""Simple helper method to raise or warn in case incorrect module has been passed"""
|
||||
if not is_pipeline_module:
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
expected_class_obj = None
|
||||
for class_name, class_candidate in class_candidates.items():
|
||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
||||
expected_class_obj = class_candidate
|
||||
|
||||
# Dynamo wraps the original model in a private class.
|
||||
# I didn't find a public API to get the original class.
|
||||
sub_model = passed_class_obj[name]
|
||||
unwrapped_sub_model = _unwrap_model(sub_model)
|
||||
model_cls = unwrapped_sub_model.__class__
|
||||
|
||||
if not issubclass(model_cls, expected_class_obj):
|
||||
raise ValueError(
|
||||
f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
|
||||
" has the correct type"
|
||||
)
|
||||
|
||||
|
||||
def get_class_obj_and_candidates(
|
||||
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
|
||||
):
|
||||
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
|
||||
component_folder = os.path.join(cache_dir, component_name)
|
||||
|
||||
if is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
||||
elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
|
||||
# load custom component
|
||||
class_obj = get_class_from_dynamic_module(
|
||||
component_folder, module_file=library_name + ".py", class_name=class_name
|
||||
)
|
||||
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
||||
else:
|
||||
# else we just import it from the library.
|
||||
library = importlib.import_module(library_name)
|
||||
|
||||
class_obj = getattr(library, class_name)
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
return class_obj, class_candidates
|
||||
|
||||
|
||||
def _get_pipeline_class(
|
||||
class_obj,
|
||||
config=None,
|
||||
load_connected_pipeline=False,
|
||||
custom_pipeline=None,
|
||||
repo_id=None,
|
||||
hub_revision=None,
|
||||
class_name=None,
|
||||
cache_dir=None,
|
||||
revision=None,
|
||||
):
|
||||
if custom_pipeline is not None:
|
||||
if custom_pipeline.endswith(".py"):
|
||||
path = Path(custom_pipeline)
|
||||
# decompose into folder & file
|
||||
file_name = path.name
|
||||
custom_pipeline = path.parent.absolute()
|
||||
elif repo_id is not None:
|
||||
file_name = f"{custom_pipeline}.py"
|
||||
custom_pipeline = repo_id
|
||||
else:
|
||||
file_name = CUSTOM_PIPELINE_FILE_NAME
|
||||
|
||||
if repo_id is not None and hub_revision is not None:
|
||||
# if we load the pipeline code from the Hub
|
||||
# make sure to overwrite the `revision`
|
||||
revision = hub_revision
|
||||
|
||||
return get_class_from_dynamic_module(
|
||||
custom_pipeline,
|
||||
module_file=file_name,
|
||||
class_name=class_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
if class_obj != DiffusionPipeline:
|
||||
return class_obj
|
||||
|
||||
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
|
||||
class_name = class_name or config["_class_name"]
|
||||
if not class_name:
|
||||
raise ValueError(
|
||||
"The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`."
|
||||
)
|
||||
|
||||
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
|
||||
|
||||
pipeline_cls = getattr(diffusers_module, class_name)
|
||||
|
||||
if load_connected_pipeline:
|
||||
from .auto_pipeline import _get_connected_pipeline
|
||||
|
||||
connected_pipeline_cls = _get_connected_pipeline(pipeline_cls)
|
||||
if connected_pipeline_cls is not None:
|
||||
logger.info(
|
||||
f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`"
|
||||
)
|
||||
else:
|
||||
logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.")
|
||||
|
||||
pipeline_cls = connected_pipeline_cls or pipeline_cls
|
||||
|
||||
return pipeline_cls
|
||||
|
||||
|
||||
def load_sub_model(
|
||||
library_name: str,
|
||||
class_name: str,
|
||||
importable_classes: List[Any],
|
||||
pipelines: Any,
|
||||
is_pipeline_module: bool,
|
||||
pipeline_class: Any,
|
||||
torch_dtype: torch.dtype,
|
||||
provider: Any,
|
||||
sess_options: Any,
|
||||
device_map: Optional[Union[Dict[str, torch.device], str]],
|
||||
max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
|
||||
offload_folder: Optional[Union[str, os.PathLike]],
|
||||
offload_state_dict: bool,
|
||||
model_variants: Dict[str, str],
|
||||
name: str,
|
||||
from_flax: bool,
|
||||
variant: str,
|
||||
low_cpu_mem_usage: bool,
|
||||
cached_folder: Union[str, os.PathLike],
|
||||
):
|
||||
"""Helper method to load the module `name` from `library_name` and `class_name`"""
|
||||
# retrieve class candidates
|
||||
class_obj, class_candidates = get_class_obj_and_candidates(
|
||||
library_name,
|
||||
class_name,
|
||||
importable_classes,
|
||||
pipelines,
|
||||
is_pipeline_module,
|
||||
component_name=name,
|
||||
cache_dir=cached_folder,
|
||||
)
|
||||
|
||||
load_method_name = None
|
||||
# retrieve load method name
|
||||
for class_name, class_candidate in class_candidates.items():
|
||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
||||
load_method_name = importable_classes[class_name][1]
|
||||
|
||||
# if load method name is None, then we have a dummy module -> raise Error
|
||||
if load_method_name is None:
|
||||
none_module = class_obj.__module__
|
||||
is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
|
||||
TRANSFORMERS_DUMMY_MODULES_FOLDER
|
||||
)
|
||||
if is_dummy_path and "dummy" in none_module:
|
||||
# call class_obj for nice error message of missing requirements
|
||||
class_obj()
|
||||
|
||||
raise ValueError(
|
||||
f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
|
||||
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
|
||||
)
|
||||
|
||||
load_method = getattr(class_obj, load_method_name)
|
||||
|
||||
# add kwargs to loading method
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
loading_kwargs = {}
|
||||
if issubclass(class_obj, torch.nn.Module):
|
||||
loading_kwargs["torch_dtype"] = torch_dtype
|
||||
if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
|
||||
loading_kwargs["provider"] = provider
|
||||
loading_kwargs["sess_options"] = sess_options
|
||||
|
||||
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
|
||||
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
else:
|
||||
transformers_version = "N/A"
|
||||
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedModel)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
|
||||
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
|
||||
# This makes sure that the weights won't be initialized which significantly speeds up loading.
|
||||
if is_diffusers_model or is_transformers_model:
|
||||
loading_kwargs["device_map"] = device_map
|
||||
loading_kwargs["max_memory"] = max_memory
|
||||
loading_kwargs["offload_folder"] = offload_folder
|
||||
loading_kwargs["offload_state_dict"] = offload_state_dict
|
||||
loading_kwargs["variant"] = model_variants.pop(name, None)
|
||||
|
||||
if from_flax:
|
||||
loading_kwargs["from_flax"] = True
|
||||
|
||||
# the following can be deleted once the minimum required `transformers` version
|
||||
# is higher than 4.27
|
||||
if (
|
||||
is_transformers_model
|
||||
and loading_kwargs["variant"] is not None
|
||||
and transformers_version < version.parse("4.27.0")
|
||||
):
|
||||
raise ImportError(
|
||||
f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
|
||||
)
|
||||
elif is_transformers_model and loading_kwargs["variant"] is None:
|
||||
loading_kwargs.pop("variant")
|
||||
|
||||
# if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
|
||||
if not (from_flax and is_transformers_model):
|
||||
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
else:
|
||||
loading_kwargs["low_cpu_mem_usage"] = False
|
||||
|
||||
# check if the module is in a subdirectory
|
||||
if os.path.isdir(os.path.join(cached_folder, name)):
|
||||
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
|
||||
else:
|
||||
# else load from the root directory
|
||||
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
|
||||
|
||||
return loaded_sub_model
|
||||
|
||||
|
||||
def _fetch_class_library_tuple(module):
|
||||
# import it here to avoid circular import
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
pipelines = getattr(diffusers_module, "pipelines")
|
||||
|
||||
# register the config from the original module, not the dynamo compiled one
|
||||
not_compiled_module = _unwrap_model(module)
|
||||
library = not_compiled_module.__module__.split(".")[0]
|
||||
|
||||
# check if the module is a pipeline module
|
||||
module_path_items = not_compiled_module.__module__.split(".")
|
||||
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
|
||||
|
||||
path = not_compiled_module.__module__.split(".")
|
||||
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
|
||||
|
||||
# if library is not in LOADABLE_CLASSES, then it is a custom module.
|
||||
# Or if it's a pipeline module, then the module is inside the pipeline
|
||||
# folder so we set the library to module name.
|
||||
if is_pipeline_module:
|
||||
library = pipeline_dir
|
||||
elif library not in LOADABLE_CLASSES:
|
||||
library = not_compiled_module.__module__
|
||||
|
||||
# retrieve class_name
|
||||
class_name = not_compiled_module.__class__.__name__
|
||||
|
||||
return (library, class_name)
|
||||
|
||||
|
||||
class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
r"""
|
||||
Base class for all pipelines.
|
||||
@@ -1079,6 +631,33 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
use_onnx = kwargs.pop("use_onnx", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `device_map=None`."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is False and device_map is not None:
|
||||
raise ValueError(
|
||||
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
|
||||
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
||||
)
|
||||
|
||||
# 1. Download the checkpoints and configs
|
||||
# use snapshot download here to get it working from from_pretrained
|
||||
if not os.path.isdir(pretrained_model_name_or_path):
|
||||
@@ -1211,33 +790,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `device_map=None`."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is False and device_map is not None:
|
||||
raise ValueError(
|
||||
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
|
||||
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
||||
)
|
||||
|
||||
# import it here to avoid circular import
|
||||
from diffusers import pipelines
|
||||
|
||||
|
||||
@@ -133,6 +133,42 @@ ASPECT_RATIO_512_BIN = {
|
||||
"4.0": [1024.0, 256.0],
|
||||
}
|
||||
|
||||
ASPECT_RATIO_256_BIN = {
|
||||
"0.25": [128.0, 512.0],
|
||||
"0.28": [128.0, 464.0],
|
||||
"0.32": [144.0, 448.0],
|
||||
"0.33": [144.0, 432.0],
|
||||
"0.35": [144.0, 416.0],
|
||||
"0.4": [160.0, 400.0],
|
||||
"0.42": [160.0, 384.0],
|
||||
"0.48": [176.0, 368.0],
|
||||
"0.5": [176.0, 352.0],
|
||||
"0.52": [176.0, 336.0],
|
||||
"0.57": [192.0, 336.0],
|
||||
"0.6": [192.0, 320.0],
|
||||
"0.68": [208.0, 304.0],
|
||||
"0.72": [208.0, 288.0],
|
||||
"0.78": [224.0, 288.0],
|
||||
"0.82": [224.0, 272.0],
|
||||
"0.88": [240.0, 272.0],
|
||||
"0.94": [240.0, 256.0],
|
||||
"1.0": [256.0, 256.0],
|
||||
"1.07": [256.0, 240.0],
|
||||
"1.13": [272.0, 240.0],
|
||||
"1.21": [272.0, 224.0],
|
||||
"1.29": [288.0, 224.0],
|
||||
"1.38": [288.0, 208.0],
|
||||
"1.46": [304.0, 208.0],
|
||||
"1.67": [320.0, 192.0],
|
||||
"1.75": [336.0, 192.0],
|
||||
"2.0": [352.0, 176.0],
|
||||
"2.09": [368.0, 176.0],
|
||||
"2.4": [384.0, 160.0],
|
||||
"2.5": [400.0, 160.0],
|
||||
"3.0": [432.0, 144.0],
|
||||
"4.0": [512.0, 128.0],
|
||||
}
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
@@ -260,6 +296,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
clean_caption: bool = False,
|
||||
max_sequence_length: int = 120,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -284,8 +321,9 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
|
||||
string.
|
||||
clean_caption (bool, defaults to `False`):
|
||||
clean_caption (`bool`, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
|
||||
"""
|
||||
|
||||
if "mask_feature" in kwargs:
|
||||
@@ -303,7 +341,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# See Section 3.1. of the paper.
|
||||
max_length = 120
|
||||
max_length = max_sequence_length
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
||||
@@ -688,6 +726,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
callback_steps: int = 1,
|
||||
clean_caption: bool = True,
|
||||
use_resolution_binning: bool = True,
|
||||
max_sequence_length: int = 120,
|
||||
**kwargs,
|
||||
) -> Union[ImagePipelineOutput, Tuple]:
|
||||
"""
|
||||
@@ -757,6 +796,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
If set to `True`, the requested height and width are first mapped to the closest resolutions using
|
||||
`ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
|
||||
the requested resolution. Useful for generating non-square images.
|
||||
max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -772,9 +812,14 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
||||
if use_resolution_binning:
|
||||
aspect_ratio_bin = (
|
||||
ASPECT_RATIO_1024_BIN if self.transformer.config.sample_size == 128 else ASPECT_RATIO_512_BIN
|
||||
)
|
||||
if self.transformer.config.sample_size == 128:
|
||||
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
|
||||
elif self.transformer.config.sample_size == 64:
|
||||
aspect_ratio_bin = ASPECT_RATIO_512_BIN
|
||||
elif self.transformer.config.sample_size == 32:
|
||||
aspect_ratio_bin = ASPECT_RATIO_256_BIN
|
||||
else:
|
||||
raise ValueError("Invalid sample size")
|
||||
orig_height, orig_width = height, width
|
||||
height, width = self.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
|
||||
|
||||
@@ -822,6 +867,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
clean_caption=clean_caption,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
|
||||
+1
@@ -136,6 +136,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_cascade"] = ["StableCascadeDecoderPipeline"]
|
||||
_import_structure["pipeline_stable_cascade_combined"] = ["StableCascadeCombinedPipeline"]
|
||||
_import_structure["pipeline_stable_cascade_prior"] = ["StableCascadePriorPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_stable_cascade import StableCascadeDecoderPipeline
|
||||
from .pipeline_stable_cascade_combined import StableCascadeCombinedPipeline
|
||||
from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,482 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...models import StableCascadeUNet
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import is_torch_version, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline
|
||||
|
||||
>>> prior_pipe = StableCascadePriorPipeline.from_pretrained(
|
||||
... "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
|
||||
... ).to("cuda")
|
||||
>>> gen_pipe = StableCascadeDecoderPipeline.from_pretrain(
|
||||
... "stabilityai/stable-cascade", torch_dtype=torch.float16
|
||||
... ).to("cuda")
|
||||
|
||||
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
>>> prior_output = pipe(prompt)
|
||||
>>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class StableCascadeDecoderPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for generating images from the Stable Cascade model.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
The CLIP tokenizer.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
The CLIP text encoder.
|
||||
decoder ([`StableCascadeUNet`]):
|
||||
The Stable Cascade decoder unet.
|
||||
vqgan ([`PaellaVQModel`]):
|
||||
The VQGAN model.
|
||||
scheduler ([`DDPMWuerstchenScheduler`]):
|
||||
A scheduler to be used in combination with `prior` to generate image embedding.
|
||||
latent_dim_scale (float, `optional`, defaults to 10.67):
|
||||
Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are
|
||||
height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and
|
||||
width=int(24*10.67)=256 in order to match the training conditions.
|
||||
"""
|
||||
|
||||
unet_name = "decoder"
|
||||
text_encoder_name = "text_encoder"
|
||||
model_cpu_offload_seq = "text_encoder->decoder->vqgan"
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds_pooled",
|
||||
"negative_prompt_embeds",
|
||||
"image_embeddings",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decoder: StableCascadeUNet,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModel,
|
||||
scheduler: DDPMWuerstchenScheduler,
|
||||
vqgan: PaellaVQModel,
|
||||
latent_dim_scale: float = 10.67,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
decoder=decoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqgan,
|
||||
)
|
||||
self.register_to_config(latent_dim_scale=latent_dim_scale)
|
||||
|
||||
def prepare_latents(self, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler):
|
||||
batch_size, channels, height, width = image_embeddings.shape
|
||||
latents_shape = (
|
||||
batch_size * num_images_per_prompt,
|
||||
4,
|
||||
int(height * self.config.latent_dim_scale),
|
||||
int(width * self.config.latent_dim_scale),
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
device,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
prompt=None,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
if prompt_embeds is None:
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
|
||||
|
||||
text_encoder_output = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True
|
||||
)
|
||||
prompt_embeds = text_encoder_output.hidden_states[-1]
|
||||
if prompt_embeds_pooled is None:
|
||||
prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device)
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
if negative_prompt_embeds is None and do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
negative_prompt_embeds_text_encoder_output = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=uncond_input.attention_mask.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.hidden_states[-1]
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
seq_len = negative_prompt_embeds_pooled.shape[1]
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to(
|
||||
dtype=self.text_encoder.dtype, device=device
|
||||
)
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
# done duplicates
|
||||
|
||||
return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image_embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]],
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_inference_steps: int = 10,
|
||||
guidance_scale: float = 0.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image_embedding (`torch.FloatTensor` or `List[torch.FloatTensor]`):
|
||||
Image Embeddings either extracted from an image or generated by a Prior Model.
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 12):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
|
||||
`decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely
|
||||
linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `decoder_guidance_scale` is less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` input
|
||||
argument.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
|
||||
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated image
|
||||
embeddings.
|
||||
"""
|
||||
|
||||
# 0. Define commonly used variables
|
||||
device = self._execution_device
|
||||
dtype = self.decoder.dtype
|
||||
self._guidance_scale = guidance_scale
|
||||
if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16:
|
||||
raise ValueError("`StableCascadeDecoderPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype.")
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
if isinstance(image_embeddings, list):
|
||||
image_embeddings = torch.cat(image_embeddings, dim=0)
|
||||
batch_size = image_embeddings.shape[0]
|
||||
|
||||
# 2. Encode caption
|
||||
if prompt_embeds is None and negative_prompt_embeds is None:
|
||||
_, prompt_embeds_pooled, _, negative_prompt_embeds_pooled = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_pooled=prompt_embeds_pooled,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
|
||||
)
|
||||
|
||||
# The pooled embeds from the prior are pooled again before being passed to the decoder
|
||||
prompt_embeds_pooled = (
|
||||
torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled])
|
||||
if self.do_classifier_free_guidance
|
||||
else prompt_embeds_pooled
|
||||
)
|
||||
effnet = (
|
||||
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
|
||||
if self.do_classifier_free_guidance
|
||||
else image_embeddings
|
||||
)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latents
|
||||
latents = self.prepare_latents(
|
||||
image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
|
||||
)
|
||||
|
||||
# 6. Run denoising loop
|
||||
self._num_timesteps = len(timesteps[:-1])
|
||||
for i, t in enumerate(self.progress_bar(timesteps[:-1])):
|
||||
timestep_ratio = t.expand(latents.size(0)).to(dtype)
|
||||
|
||||
# 7. Denoise latents
|
||||
predicted_latents = self.decoder(
|
||||
sample=torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
|
||||
timestep_ratio=torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio,
|
||||
clip_text_pooled=prompt_embeds_pooled,
|
||||
effnet=effnet,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# 8. Check for classifier free guidance and apply it
|
||||
if self.do_classifier_free_guidance:
|
||||
predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2)
|
||||
predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)
|
||||
|
||||
# 9. Renoise latents to next timestep
|
||||
latents = self.scheduler.step(
|
||||
model_output=predicted_latents,
|
||||
timestep=timestep_ratio,
|
||||
sample=latents,
|
||||
generator=generator,
|
||||
).prev_sample
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if output_type not in ["pt", "np", "pil", "latent"]:
|
||||
raise ValueError(
|
||||
f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}"
|
||||
)
|
||||
|
||||
if not output_type == "latent":
|
||||
# 10. Scale and decode the image latents with vq-vae
|
||||
latents = self.vqgan.config.scale_factor * latents
|
||||
images = self.vqgan.decode(latents).sample.clamp(0, 1)
|
||||
if output_type == "np":
|
||||
images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesnt work
|
||||
elif output_type == "pil":
|
||||
images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesnt work
|
||||
images = self.numpy_to_pil(images)
|
||||
else:
|
||||
images = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return images
|
||||
return ImagePipelineOutput(images)
|
||||
@@ -0,0 +1,306 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...models import StableCascadeUNet
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
|
||||
from .pipeline_stable_cascade import StableCascadeDecoderPipeline
|
||||
from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
|
||||
|
||||
|
||||
TEXT2IMAGE_EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusions import StableCascadeCombinedPipeline
|
||||
|
||||
>>> pipe = StableCascadeCombinedPipeline.from_pretrained("stabilityai/stable-cascade-combined", torch_dtype=torch.bfloat16).to(
|
||||
... "cuda"
|
||||
... )
|
||||
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
>>> images = pipe(prompt=prompt)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Combined Pipeline for text-to-image generation using Stable Cascade.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
The decoder tokenizer to be used for text inputs.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
The decoder text encoder to be used for text inputs.
|
||||
decoder (`StableCascadeUNet`):
|
||||
The decoder model to be used for decoder image generation pipeline.
|
||||
scheduler (`DDPMWuerstchenScheduler`):
|
||||
The scheduler to be used for decoder image generation pipeline.
|
||||
vqgan (`PaellaVQModel`):
|
||||
The VQGAN model to be used for decoder image generation pipeline.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
|
||||
image_encoder ([`CLIPVisionModelWithProjection`]):
|
||||
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
prior_prior (`StableCascadeUNet`):
|
||||
The prior model to be used for prior pipeline.
|
||||
prior_scheduler (`DDPMWuerstchenScheduler`):
|
||||
The scheduler to be used for prior pipeline.
|
||||
"""
|
||||
|
||||
_load_connected_pipes = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModel,
|
||||
decoder: StableCascadeUNet,
|
||||
scheduler: DDPMWuerstchenScheduler,
|
||||
vqgan: PaellaVQModel,
|
||||
prior_prior: StableCascadeUNet,
|
||||
prior_text_encoder: CLIPTextModel,
|
||||
prior_tokenizer: CLIPTokenizer,
|
||||
prior_scheduler: DDPMWuerstchenScheduler,
|
||||
prior_feature_extractor: Optional[CLIPImageProcessor] = None,
|
||||
prior_image_encoder: Optional[CLIPVisionModelWithProjection] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
decoder=decoder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqgan,
|
||||
prior_text_encoder=prior_text_encoder,
|
||||
prior_tokenizer=prior_tokenizer,
|
||||
prior_prior=prior_prior,
|
||||
prior_scheduler=prior_scheduler,
|
||||
prior_feature_extractor=prior_feature_extractor,
|
||||
prior_image_encoder=prior_image_encoder,
|
||||
)
|
||||
self.prior_pipe = StableCascadePriorPipeline(
|
||||
prior=prior_prior,
|
||||
text_encoder=prior_text_encoder,
|
||||
tokenizer=prior_tokenizer,
|
||||
scheduler=prior_scheduler,
|
||||
image_encoder=prior_image_encoder,
|
||||
feature_extractor=prior_feature_extractor,
|
||||
)
|
||||
self.decoder_pipe = StableCascadeDecoderPipeline(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
decoder=decoder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqgan,
|
||||
)
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
||||
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
|
||||
self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
|
||||
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
|
||||
GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
|
||||
Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
|
||||
def progress_bar(self, iterable=None, total=None):
|
||||
self.prior_pipe.progress_bar(iterable=iterable, total=total)
|
||||
self.decoder_pipe.progress_bar(iterable=iterable, total=total)
|
||||
|
||||
def set_progress_bar_config(self, **kwargs):
|
||||
self.prior_pipe.set_progress_bar_config(**kwargs)
|
||||
self.decoder_pipe.set_progress_bar_config(**kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
prior_num_inference_steps: int = 60,
|
||||
prior_guidance_scale: float = 4.0,
|
||||
num_inference_steps: int = 12,
|
||||
decoder_guidance_scale: float = 0.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation for the prior and decoder.
|
||||
images (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, *optional*):
|
||||
The images to guide the image generation for the prior.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
|
||||
prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
|
||||
prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`prior_guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
|
||||
`prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked
|
||||
to the text `prompt`, usually at the expense of lower image quality.
|
||||
prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 60):
|
||||
The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference. For more specific timestep spacing, you can pass customized
|
||||
`prior_timesteps`
|
||||
num_inference_steps (`int`, *optional*, defaults to 12):
|
||||
The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at
|
||||
the expense of slower inference. For more specific timestep spacing, you can pass customized
|
||||
`timesteps`
|
||||
decoder_guidance_scale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
prior_callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep:
|
||||
int, callback_kwargs: Dict)`.
|
||||
prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
|
||||
list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
|
||||
the `._callback_tensor_inputs` attribute of your pipeine class.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
|
||||
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
prior_outputs = self.prior_pipe(
|
||||
prompt=prompt if prompt_embeds is None else None,
|
||||
images=images,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=prior_num_inference_steps,
|
||||
guidance_scale=prior_guidance_scale,
|
||||
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_pooled=prompt_embeds_pooled,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
output_type="pt",
|
||||
return_dict=True,
|
||||
callback_on_step_end=prior_callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
image_embeddings = prior_outputs.image_embeddings
|
||||
prompt_embeds = prior_outputs.get("prompt_embeds", None)
|
||||
prompt_embeds_pooled = prior_outputs.get("prompt_embeds_pooled", None)
|
||||
negative_prompt_embeds = prior_outputs.get("negative_prompt_embeds", None)
|
||||
negative_prompt_embeds_pooled = prior_outputs.get("negative_prompt_embeds_pooled", None)
|
||||
|
||||
outputs = self.decoder_pipe(
|
||||
image_embeddings=image_embeddings,
|
||||
prompt=prompt if prompt_embeds is None else None,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=decoder_guidance_scale,
|
||||
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_pooled=prompt_embeds_pooled,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
|
||||
generator=generator,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user